From 28fe26637b3d94b45444cbe5fa9cab921257cf09 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 29 Apr 2026 18:01:07 +0900 Subject: [PATCH 01/12] [client] Fix Windows installer upgrade detection for pre-0.70.1 installs (#6025) --- client/installer.nsis | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/client/installer.nsis b/client/installer.nsis index 8b2b8ea39..6b8d3258e 100644 --- a/client/installer.nsis +++ b/client/installer.nsis @@ -200,9 +200,17 @@ Pop $0 !macroend Function .onInit -SetRegView 64 StrCpy $INSTDIR "${INSTALL_DIR}" + +; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live +; in the 32-bit view. Fall back to it so upgrades still find them. +SetRegView 64 ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString" +${If} $R0 == "" + SetRegView 32 + ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString" + SetRegView 64 +${EndIf} ${If} $R0 != "" # if silent install jump to uninstall step IfSilent uninstall From 7eba5dafd8bbba9e5a0c4e8bd34d14dfda565db9 Mon Sep 17 00:00:00 2001 From: Nicolas Frati Date: Wed, 29 Apr 2026 11:28:55 +0200 Subject: [PATCH 02/12] [misc] Add comment automation on release workflow for PRs (#6016) * feat: add comment automation on release workflow for PRs * update action permissions --- .github/workflows/release.yml | 156 ++++++++++++++++++++++++++++++++-- 1 file changed, 150 insertions(+), 6 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 826c05ff3..081bcafc4 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -115,6 +115,12 @@ jobs: release: runs-on: ubuntu-latest-m + outputs: + release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }} + linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }} + windows_packages_artifact_url: ${{ 
steps.upload_windows_packages.outputs.artifact-url }} + macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }} + ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }} env: flags: "" steps: @@ -213,10 +219,13 @@ jobs: if: always() run: rm -f /tmp/gpg-rpm-signing-key.asc - name: Tag and push images (amd64 only) + id: tag_and_push_images if: | (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) || (github.event_name == 'push' && github.ref == 'refs/heads/main') run: | + set -euo pipefail + resolve_tags() { if [[ "${{ github.event_name }}" == "pull_request" ]]; then echo "pr-${{ github.event.pull_request.number }}" @@ -225,6 +234,17 @@ jobs: fi } + ghcr_package_url() { + local image="$1" package encoded_package + package="${image#ghcr.io/}" + package="${package#*/}" + package="${package%%:*}" + encoded_package="${package//\//%2F}" + echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}" + } + + image_refs=() + tag_and_push() { local src="$1" img_name tag dst img_name="${src%%:*}" @@ -233,35 +253,56 @@ jobs: echo "Tagging ${src} -> ${dst}" docker tag "$src" "$dst" docker push "$dst" + image_refs+=("$dst") done } - export -f tag_and_push resolve_tags + cat > /tmp/goreleaser-artifacts.json <<'JSON' + ${{ steps.goreleaser.outputs.artifacts }} + JSON - echo '${{ steps.goreleaser.outputs.artifacts }}' | \ - jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \ - grep '^ghcr.io/' | while read -r SRC; do - tag_and_push "$SRC" - done + mapfile -t src_images < <( + jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json + ) + + for src in "${src_images[@]}"; do + tag_and_push "$src" + done + + { + echo "images_markdown<> "$GITHUB_OUTPUT" - name: upload non tags for debug purposes + id: upload_release uses: 
actions/upload-artifact@v4 with: name: release path: dist/ retention-days: 7 - name: upload linux packages + id: upload_linux_packages uses: actions/upload-artifact@v4 with: name: linux-packages path: dist/netbird_linux** retention-days: 7 - name: upload windows packages + id: upload_windows_packages uses: actions/upload-artifact@v4 with: name: windows-packages path: dist/netbird_windows** retention-days: 7 - name: upload macos packages + id: upload_macos_packages uses: actions/upload-artifact@v4 with: name: macos-packages @@ -270,6 +311,8 @@ jobs: release_ui: runs-on: ubuntu-latest + outputs: + release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }} steps: - name: Parse semver string id: semver_parser @@ -360,6 +403,7 @@ jobs: if: always() run: rm -f /tmp/gpg-rpm-signing-key.asc - name: upload non tags for debug purposes + id: upload_release_ui uses: actions/upload-artifact@v4 with: name: release-ui @@ -368,6 +412,8 @@ jobs: release_ui_darwin: runs-on: macos-latest + outputs: + release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }} steps: - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV @@ -402,12 +448,110 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: upload non tags for debug purposes + id: upload_release_ui_darwin uses: actions/upload-artifact@v4 with: name: release-ui-darwin path: dist/ retention-days: 3 + comment_release_artifacts: + name: Comment release artifacts + runs-on: ubuntu-latest + needs: [release, release_ui, release_ui_darwin] + if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }} + permissions: + contents: read + issues: write + pull-requests: write + steps: + - name: Create or update PR comment + uses: actions/github-script@v7 + env: + RELEASE_RESULT: ${{ needs.release.result }} + RELEASE_UI_RESULT: ${{ needs.release_ui.result }} + 
RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }} + RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }} + LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }} + WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }} + MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }} + RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }} + RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }} + GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }} + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const marker = ''; + const { owner, repo } = context.repo; + const issue_number = context.payload.pull_request.number; + const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`; + const shortSha = context.payload.pull_request.head.sha.slice(0, 7); + + const artifactCell = (url, result) => { + if (url) return `[Download](${url})`; + return result && result !== 'success' ? 
`_Not available (${result})_` : '_Not available_'; + }; + + const artifacts = [ + ['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT], + ['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT], + ['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT], + ['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT], + ['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT], + ['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT], + ]; + + const artifactRows = artifacts + .map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`) + .join('\n'); + + const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._'; + + const body = [ + marker, + '## Release artifacts', + '', + `Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`, + '', + '| Artifact | Link |', + '| --- | --- |', + artifactRows, + '', + '### GHCR images (amd64)', + ghcrImages, + '', + '_This comment is updated by the Release workflow. 
Artifact links expire according to the workflow retention policy._', + ].join('\n'); + + const comments = await github.paginate(github.rest.issues.listComments, { + owner, + repo, + issue_number, + per_page: 100, + }); + + const previous = comments.find(comment => + comment.user?.type === 'Bot' && comment.body?.includes(marker) + ); + + if (previous) { + await github.rest.issues.updateComment({ + owner, + repo, + comment_id: previous.id, + body, + }); + core.info(`Updated release artifacts comment ${previous.id}`); + } else { + const { data } = await github.rest.issues.createComment({ + owner, + repo, + issue_number, + body, + }); + core.info(`Created release artifacts comment ${data.id}`); + } + trigger_signer: runs-on: ubuntu-latest needs: [release, release_ui, release_ui_darwin] From ad93dcf9807e46ac648ff67b0ab994696b8cb6fc Mon Sep 17 00:00:00 2001 From: shuuri-labs <61762328+shuuri-labs@users.noreply.github.com> Date: Wed, 29 Apr 2026 13:14:46 +0200 Subject: [PATCH 03/12] [client] Enable UI autostart for silent and MSI installs (#6026) * fix(client): enable UI autostart for silent and MSI installs The MSI installer had no autostart logic and the EXE silent installer skipped the autostart page, leaving the registry entry unwritten. This caused the NetBird UI tray to not start at login after RMM deployments. Add an AUTOSTART property (default: 1) to the MSI that writes the HKLM Run key, and initialize AutostartEnabled in the NSIS .onInit so silent installs match the interactive default. 
* add real guid for NetBirdAutoStart component --- client/installer.nsis | 2 ++ client/netbird.wxs | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/client/installer.nsis b/client/installer.nsis index 6b8d3258e..63bff1c5b 100644 --- a/client/installer.nsis +++ b/client/installer.nsis @@ -201,6 +201,8 @@ Pop $0 Function .onInit StrCpy $INSTDIR "${INSTALL_DIR}" +; Default autostart to enabled so silent installs (/S) match the interactive default +StrCpy $AutostartEnabled "1" ; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live ; in the 32-bit view. Fall back to it so upgrades still find them. diff --git a/client/netbird.wxs b/client/netbird.wxs index 23aa250f4..2849bc6b9 100644 --- a/client/netbird.wxs +++ b/client/netbird.wxs @@ -13,6 +13,9 @@ + + + @@ -63,9 +66,21 @@ + + + + AUTOSTART = "1" + + + + + + From df197d5001c19dbeedb6e4bb44f51a6d298b3422 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 29 Apr 2026 15:04:27 +0300 Subject: [PATCH 04/12] [management] Prevent JWT reuse during peer login (#6002) --- client/cmd/testutil_test.go | 2 +- client/internal/engine_test.go | 2 +- client/server/server_test.go | 2 +- management/internals/server/boot.go | 2 +- management/internals/server/controllers.go | 7 ++ management/internals/shared/grpc/server.go | 37 +++++++++- management/server/auth/session.go | 61 ++++++++++++++++ management/server/auth/session_test.go | 82 ++++++++++++++++++++++ management/server/management_proto_test.go | 2 +- management/server/management_test.go | 1 + shared/management/client/client_test.go | 2 +- 11 files changed, 192 insertions(+), 8 deletions(-) create mode 100644 management/server/auth/session.go create mode 100644 management/server/auth/session_test.go diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index d7564c353..fd1007bb4 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -135,7 +135,7 @@ func startManagement(t *testing.T, config 
*config.Config, testFile string) (*grp if err != nil { t.Fatal(err) } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9fa4e51b2..f4c5be70a 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1671,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri if err != nil { return nil, "", err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index 772997575..54ad47e55 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -335,7 +335,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { return nil, "", err } diff --git a/management/internals/server/boot.go 
b/management/internals/server/boot.go index 2b40c0aad..f2ab0a2c4 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -173,7 +173,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { } gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider()) + srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), s.SessionStore()) if err != nil { log.Fatalf("failed to create management server: %v", err) } diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 9a8e45d33..89bdf0abe 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" @@ -66,6 +67,12 @@ func (s *BaseServer) SecretsManager() grpc.SecretsManager { }) } +func (s *BaseServer) SessionStore() *auth.SessionStore { + return Create(s, func() *auth.SessionStore { + return auth.NewSessionStore(s.CacheStore()) + }) +} + func (s *BaseServer) AuthManager() auth.Manager { audiences := s.Config.GetAuthAudiences() audience := s.Config.HttpConfig.AuthAudience diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 6e8358f02..0c1611e7f 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -14,6 +14,7 @@ import ( 
"sync/atomic" "time" + jwtv5 "github.com/golang-jwt/jwt/v5" pb "github.com/golang/protobuf/proto" // nolint "github.com/golang/protobuf/ptypes/timestamp" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" @@ -67,6 +68,7 @@ type Server struct { appMetrics telemetry.AppMetrics peerLocks sync.Map authManager auth.Manager + sessionStore *auth.SessionStore logBlockedPeers bool blockPeersWithSameConfig bool @@ -98,6 +100,7 @@ func NewServer( integratedPeerValidator integrated_validator.IntegratedValidator, networkMapController network_map.Controller, oAuthConfigProvider idp.OAuthConfigProvider, + sessionStore *auth.SessionStore, ) (*Server, error) { if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams @@ -140,6 +143,7 @@ func NewServer( integratedPeerValidator: integratedPeerValidator, networkMapController: networkMapController, oAuthConfigProvider: oAuthConfigProvider, + sessionStore: sessionStore, loginFilter: newLoginFilter(), @@ -535,7 +539,7 @@ func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID st log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key) } -func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) { +func (s *Server) validateToken(ctx context.Context, peerKey, jwtToken string) (string, error) { if s.authManager == nil { return "", status.Errorf(codes.Internal, "missing auth manager") } @@ -545,6 +549,10 @@ func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, er return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) } + if err := s.claimLoginToken(ctx, peerKey, jwtToken, token); err != nil { + return "", err + } + // we need to call this method because if user is new, we will automatically add it to existing or create a new account accountId, _, err := s.accountManager.GetAccountIDFromUserAuth(ctx, userAuth) if err != nil { @@ -828,6 +836,31 @@ func (s 
*Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne return loginResp, nil } +func (s *Server) claimLoginToken(ctx context.Context, peerKey, jwtToken string, token *jwtv5.Token) error { + if s.sessionStore == nil || token == nil { + return nil + } + + exp, err := token.Claims.GetExpirationTime() + if err != nil || exp == nil { + log.WithContext(ctx).Warnf("JWT has no usable exp claim for peer %s", peerKey) + return status.Error(codes.Unauthenticated, "jwt token has no expiration") + } + + err = s.sessionStore.RegisterToken(ctx, jwtToken, exp.Time) + if err == nil { + return nil + } + + if errors.Is(err, auth.ErrTokenAlreadyUsed) || errors.Is(err, auth.ErrTokenExpired) { + log.WithContext(ctx).Warnf("%v for peer %s", err, peerKey) + return status.Error(codes.Unauthenticated, err.Error()) + } + + log.WithContext(ctx).Warnf("failed to claim JWT for peer %s: %v", peerKey, err) + return status.Error(codes.Unavailable, "failed to claim jwt token") +} + // processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if // the token is valid. 
// @@ -838,7 +871,7 @@ func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginReque if loginReq.GetJwtToken() != "" { var err error for i := 0; i < 3; i++ { - userID, err = s.validateToken(ctx, loginReq.GetJwtToken()) + userID, err = s.validateToken(ctx, peerKey.String(), loginReq.GetJwtToken()) if err == nil { break } diff --git a/management/server/auth/session.go b/management/server/auth/session.go new file mode 100644 index 000000000..7621a1c10 --- /dev/null +++ b/management/server/auth/session.go @@ -0,0 +1,61 @@ +package auth + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "time" + + "github.com/eko/gocache/lib/v4/cache" + "github.com/eko/gocache/lib/v4/store" +) + +const ( + usedTokenKeyPrefix = "jwt-used:" + usedTokenMarker = "1" +) + +var ( + ErrTokenAlreadyUsed = errors.New("JWT already used") + ErrTokenExpired = errors.New("JWT expired") +) + +type SessionStore struct { + cache *cache.Cache[string] +} + +func NewSessionStore(cacheStore store.StoreInterface) *SessionStore { + return &SessionStore{cache: cache.New[string](cacheStore)} +} + +// RegisterToken records a JWT until its exp time and rejects reuse. 
+func (s *SessionStore) RegisterToken(ctx context.Context, token string, expiresAt time.Time) error { + ttl := time.Until(expiresAt) + if ttl <= 0 { + return ErrTokenExpired + } + + key := usedTokenKeyPrefix + hashToken(token) + _, err := s.cache.Get(ctx, key) + if err == nil { + return ErrTokenAlreadyUsed + } + + var notFound *store.NotFound + if !errors.As(err, &notFound) { + return fmt.Errorf("failed to lookup used token entry: %w", err) + } + + if err := s.cache.Set(ctx, key, usedTokenMarker, store.WithExpiration(ttl)); err != nil { + return fmt.Errorf("failed to store used token entry: %w", err) + } + + return nil +} + +func hashToken(token string) string { + sum := sha256.Sum256([]byte(token)) + return hex.EncodeToString(sum[:]) +} diff --git a/management/server/auth/session_test.go b/management/server/auth/session_test.go new file mode 100644 index 000000000..3a7d85f4c --- /dev/null +++ b/management/server/auth/session_test.go @@ -0,0 +1,82 @@ +package auth + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbcache "github.com/netbirdio/netbird/management/server/cache" +) + +func newTestSessionStore(t *testing.T) *SessionStore { + t.Helper() + cacheStore, err := nbcache.NewStore(context.Background(), time.Hour, time.Hour, 100) + require.NoError(t, err) + return NewSessionStore(cacheStore) +} + +func TestSessionStore_FirstRegisterSucceeds(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + + require.NoError(t, s.RegisterToken(ctx, "token", time.Now().Add(time.Hour))) +} + +func TestSessionStore_RegisterSameTokenTwiceIsRejected(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + token := "token" + exp := time.Now().Add(time.Hour) + + require.NoError(t, s.RegisterToken(ctx, token, exp)) + + err := s.RegisterToken(ctx, token, exp) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTokenAlreadyUsed) +} + +func
TestSessionStore_RegisterDifferentTokensAreIndependent(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + exp := time.Now().Add(time.Hour) + + require.NoError(t, s.RegisterToken(ctx, "tokenA", exp)) + require.NoError(t, s.RegisterToken(ctx, "tokenB", exp)) +} + +func TestSessionStore_RegisterWithPastExpiryIsRejected(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + token := "token" + + err := s.RegisterToken(ctx, token, time.Now().Add(-time.Second)) + require.Error(t, err) + assert.ErrorIs(t, err, ErrTokenExpired) +} + +func TestSessionStore_EntryEvictsAtTTLAndAllowsReRegistration(t *testing.T) { + s := newTestSessionStore(t) + ctx := context.Background() + token := "token" + + require.NoError(t, s.RegisterToken(ctx, token, time.Now().Add(50*time.Millisecond))) + + err := s.RegisterToken(ctx, token, time.Now().Add(50*time.Millisecond)) + assert.ErrorIs(t, err, ErrTokenAlreadyUsed) + + time.Sleep(120 * time.Millisecond) + + require.NoError(t, s.RegisterToken(ctx, token, time.Now().Add(time.Hour))) +} + +func TestHashToken_StableAndDoesNotLeak(t *testing.T) { + a := hashToken("tokenA") + b := hashToken("tokenB") + assert.Equal(t, a, hashToken("tokenA"), "hash must be deterministic") + assert.NotEqual(t, a, b, "different tokens must hash differently") + assert.Len(t, a, 64, "sha256 hex must be 64 chars") + assert.NotContains(t, a, "tokenA", "raw token must not appear in hash") +} diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 18d85315d..1b77ea335 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -391,7 +391,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config return nil, nil, "", cleanup, err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil) 
+ mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { return nil, nil, "", cleanup, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 3ac28cd4a..f1d49193c 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -256,6 +256,7 @@ func startServer( server.MockIntegratedValidator{}, networkMapController, nil, + nil, ) if err != nil { t.Fatalf("failed creating management server: %v", err) diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index d9a1a7d65..a8e8172dc 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -138,7 +138,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { if err != nil { t.Fatal(err) } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { t.Fatal(err) } From 11ac2af2f5130899633b31ac575a683afea7e308 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 29 Apr 2026 23:07:33 +0900 Subject: [PATCH 05/12] Use BindListener for all userspace bind in lazyconn activity (#6028) --- .../lazyconn/activity/listener_bind_test.go | 13 ------------- client/internal/lazyconn/activity/manager.go | 12 ------------ 2 files changed, 25 deletions(-) diff --git a/client/internal/lazyconn/activity/listener_bind_test.go b/client/internal/lazyconn/activity/listener_bind_test.go index f86dd3877..1baaae6be 100644 --- 
a/client/internal/lazyconn/activity/listener_bind_test.go +++ b/client/internal/lazyconn/activity/listener_bind_test.go @@ -3,7 +3,6 @@ package activity import ( "net" "net/netip" - "runtime" "testing" "time" @@ -18,10 +17,6 @@ import ( peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) -func isBindListenerPlatform() bool { - return runtime.GOOS == "windows" || runtime.GOOS == "js" -} - // mockEndpointManager implements device.EndpointManager for testing type mockEndpointManager struct { endpoints map[netip.Addr]net.Conn @@ -181,10 +176,6 @@ func TestBindListener_Close(t *testing.T) { } func TestManager_BindMode(t *testing.T) { - if !isBindListenerPlatform() { - t.Skip("BindListener only used on Windows/JS platforms") - } - mockEndpointMgr := newMockEndpointManager() mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} @@ -226,10 +217,6 @@ func TestManager_BindMode(t *testing.T) { } func TestManager_BindMode_MultiplePeers(t *testing.T) { - if !isBindListenerPlatform() { - t.Skip("BindListener only used on Windows/JS platforms") - } - mockEndpointMgr := newMockEndpointManager() mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index 1c11378c8..cccc0669f 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -4,14 +4,12 @@ import ( "errors" "net" "net/netip" - "runtime" "sync" "time" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" @@ -75,16 +73,6 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) return NewUDPListener(m.wgIface, peerCfg) } - // BindListener is used on Windows, JS, and netstack 
platforms: - // - JS: Cannot listen to UDP sockets - // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default - // gateway points to, preventing them from reaching the loopback interface. - // - Netstack: Allows multiple instances on the same host without port conflicts. - // BindListener bypasses these issues by passing data directly through the bind. - if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() { - return NewUDPListener(m.wgIface, peerCfg) - } - provider, ok := m.wgIface.(bindProvider) if !ok { return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider") From ed828b7af4e25e64d5f2fdccaaa1285964e27bd8 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 29 Apr 2026 23:08:47 +0900 Subject: [PATCH 06/12] Tolerate EEXIST when adding macOS scoped default routes (#6027) --- .../routemanager/systemops/systemops_darwin.go | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_darwin.go b/client/internal/routemanager/systemops/systemops_darwin.go index d6875ff95..3fcac4c6a 100644 --- a/client/internal/routemanager/systemops/systemops_darwin.go +++ b/client/internal/routemanager/systemops/systemops_darwin.go @@ -89,8 +89,16 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) { return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec) } + reused := false if err := r.addScopedDefault(unspec, nexthop); err != nil { - return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err) + if !errors.Is(err, unix.EEXIST) { + return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err) + } + // macOS installs its own RTF_IFSCOPE defaults for primary service + // selection on multi-NIC setups, so a route on this ifindex can + // already exist before we try. 
Binding to it via IP[V6]_BOUND_IF + // still produces the scoped lookup we need. + reused = true } af := unix.AF_INET @@ -102,7 +110,11 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) { if nexthop.IP.IsValid() { via = nexthop.IP.String() } - log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec)) + verb := "installed" + if reused { + verb = "reused existing" + } + log.Infof("%s scoped default route via %s on %s for %s", verb, via, nexthop.Intf.Name, afOf(unspec)) return true, nil } From 57945fc3286a4a7c7f06c688fb251e90e38bfbce Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 29 Apr 2026 17:19:22 +0200 Subject: [PATCH 07/12] [client] Trigger mobile submodule bump PRs on release tags (#6029) Trigger mobile submodule bump PRs on release tags --- .github/workflows/sync-tag.yml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/.github/workflows/sync-tag.yml b/.github/workflows/sync-tag.yml index 1cc553b12..a75d9a9d5 100644 --- a/.github/workflows/sync-tag.yml +++ b/.github/workflows/sync-tag.yml @@ -9,6 +9,8 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} cancel-in-progress: true +# Receiving workflows (cloud sync-tag, mobile bump-netbird) expect the short +# tag form (e.g. v0.30.0), not refs/tags/v0.30.0 — github.ref_name, not github.ref. 
jobs: trigger_sync_tag: runs-on: ubuntu-latest @@ -20,4 +22,30 @@ jobs: ref: main repo: ${{ secrets.UPSTREAM_REPO }} token: ${{ secrets.NC_GITHUB_TOKEN }} + inputs: '{ "tag": "${{ github.ref_name }}" }' + + trigger_android_bump: + runs-on: ubuntu-latest + if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-') + steps: + - name: Trigger android-client submodule bump + uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1 + with: + workflow: bump-netbird.yml + ref: main + repo: netbirdio/android-client + token: ${{ secrets.NC_GITHUB_TOKEN }} + inputs: '{ "tag": "${{ github.ref_name }}" }' + + trigger_ios_bump: + runs-on: ubuntu-latest + if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-') + steps: + - name: Trigger ios-client submodule bump + uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1 + with: + workflow: bump-netbird.yml + ref: main + repo: netbirdio/ios-client + token: ${{ secrets.NC_GITHUB_TOKEN }} inputs: '{ "tag": "${{ github.ref_name }}" }' \ No newline at end of file From 3fc5a8d4a1fe308ff1068764a09b90b0859ab8fe Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 29 Apr 2026 23:44:38 +0200 Subject: [PATCH 08/12] [misc] fix MSI generation add installer tests (#6031) Add Windows installer build test workflow --- .github/workflows/release.yml | 149 +++++++++++++++++++++++++++++++++- client/netbird.wxs | 3 +- 2 files changed, 148 insertions(+), 4 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 081bcafc4..c1ae01a98 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -114,7 +114,7 @@ jobs: retention-days: 30 release: - runs-on: ubuntu-latest-m + runs-on: ubuntu-24.04-8-core outputs: release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }} linux_packages_artifact_url: 
${{ steps.upload_linux_packages.outputs.artifact-url }} @@ -455,6 +455,151 @@ jobs: path: dist/ retention-days: 3 + test_windows_installer: + name: "Windows Installer / Build Test" + runs-on: windows-2022 + needs: [release, release_ui] + strategy: + fail-fast: false + matrix: + include: + - arch: amd64 + wintun_arch: amd64 + - arch: arm64 + wintun_arch: arm64 + defaults: + run: + shell: powershell + env: + PackageWorkdir: netbird_windows_${{ matrix.arch }} + downloadPath: '${{ github.workspace }}\temp' + steps: + - name: Parse semver string + id: semver_parser + uses: booxmedialtd/ws-action-parse-semver@v1 + with: + input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }} + version_extractor_regex: '\/v(.*)$' + + - name: Checkout + uses: actions/checkout@v4 + + - name: Add 7-Zip to PATH + run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + + - name: Download release artifacts + uses: actions/download-artifact@v4 + with: + name: release + path: release + + - name: Download UI release artifacts + uses: actions/download-artifact@v4 + with: + name: release-ui + path: release-ui + + - name: Stage binaries into dist + run: | + $workdir = "dist\${{ env.PackageWorkdir }}" + New-Item -ItemType Directory -Force -Path $workdir | Out-Null + $client = Get-ChildItem -Recurse -Path release -Filter "netbird_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1 + $ui = Get-ChildItem -Recurse -Path release-ui -Filter "netbird-ui-windows_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1 + if (-not $client) { Write-Host "::error::client tarball not found for ${{ matrix.arch }}"; exit 1 } + if (-not $ui) { Write-Host "::error::ui tarball not found for ${{ matrix.arch }}"; exit 1 } + Write-Host "Client: $($client.FullName)" + Write-Host "UI: $($ui.FullName)" + tar -zvxf $client.FullName -C $workdir + tar -zvxf $ui.FullName -C $workdir + Get-ChildItem $workdir + + - name: Download 
wintun + uses: carlosperate/download-file-action@v2 + id: download-wintun + with: + file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip + file-name: wintun.zip + location: ${{ env.downloadPath }} + sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51' + + - name: Decompress wintun files + run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }} + + - name: Move wintun.dll into dist + run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\ + + - name: Download Mesa3D (amd64 only) + uses: carlosperate/download-file-action@v2 + id: download-mesa3d + if: matrix.arch == 'amd64' + with: + file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z + file-name: mesa3d.7z + location: ${{ env.downloadPath }} + sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9' + + - name: Extract Mesa3D driver (amd64 only) + if: matrix.arch == 'amd64' + run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z" + + - name: Move opengl32.dll into dist (amd64 only) + if: matrix.arch == 'amd64' + run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\ + + - name: Download EnVar plugin for NSIS + uses: carlosperate/download-file-action@v2 + with: + file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip + file-name: envar_plugin.zip + location: ${{ github.workspace }} + + - name: Extract EnVar plugin + run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip" + + - name: Download ShellExecAsUser plugin for NSIS (amd64 only) + uses: carlosperate/download-file-action@v2 + if: matrix.arch == 'amd64' + with: + file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z + file-name: ShellExecAsUser_amd64-Unicode.7z + location: ${{ github.workspace }} + + - 
name: Extract ShellExecAsUser plugin (amd64 only) + if: matrix.arch == 'amd64' + run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z" + + - name: Build NSIS installer + uses: joncloud/makensis-action@v3.3 + with: + additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins + script-file: client/installer.nsis + arguments: "/V4 /DARCH=${{ matrix.arch }}" + env: + APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }} + + - name: Rename NSIS installer + run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe + + - name: Install WiX + run: | + dotnet tool install --global wix --version 6.0.2 + wix extension add WixToolset.Util.wixext/6.0.2 + + - name: Build MSI installer + env: + NETBIRD_VERSION: "${{ steps.semver_parser.outputs.fullversion }}" + run: wix build -arch ${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -ext WixToolset.Util.wixext -o netbird_installer_test_windows_${{ matrix.arch }}.msi .\client\netbird.wxs -d ProcessorArchitecture=${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -d ArchSuffix=${{ matrix.arch }} + + - name: Upload installer artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: windows-installer-test-${{ matrix.arch }} + path: | + netbird_installer_test_windows_${{ matrix.arch }}.exe + netbird_installer_test_windows_${{ matrix.arch }}.msi + retention-days: 3 + comment_release_artifacts: name: Comment release artifacts runs-on: ubuntu-latest @@ -554,7 +699,7 @@ jobs: trigger_signer: runs-on: ubuntu-latest - needs: [release, release_ui, release_ui_darwin] + needs: [release, release_ui, release_ui_darwin, test_windows_installer] if: startsWith(github.ref, 'refs/tags/') steps: - name: Trigger binaries sign pipelines diff --git a/client/netbird.wxs b/client/netbird.wxs index 2849bc6b9..6f18b63b5 100644 --- a/client/netbird.wxs +++ 
b/client/netbird.wxs @@ -68,8 +68,7 @@ - - AUTOSTART = "1" + From f29f5a09784380a3003ef3de5a2c7de4b5733657 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:52:54 +0200 Subject: [PATCH 09/12] [management] add monitoring for nmap update source (#6036) --- .../network_map/controller/controller.go | 15 ++++++-- .../controllers/network_map/interface.go | 4 +- .../controllers/network_map/interface_mock.go | 16 ++++---- .../peers/ephemeral/manager/ephemeral_test.go | 4 +- management/internals/modules/peers/manager.go | 3 +- .../service/manager/l4_port_test.go | 2 +- .../reverseproxy/service/manager/manager.go | 17 +++++---- .../service/manager/manager_test.go | 10 ++--- .../modules/zones/manager/manager.go | 5 ++- .../modules/zones/records/manager/manager.go | 7 ++-- management/server/account.go | 4 +- management/server/account/manager.go | 4 +- management/server/account/manager_mock.go | 16 ++++---- management/server/dns.go | 2 +- management/server/group.go | 16 ++++---- management/server/mock_server/account_mock.go | 12 +++--- management/server/nameserver.go | 6 +-- management/server/networks/manager.go | 3 +- .../server/networks/resources/manager.go | 6 +-- management/server/networks/routers/manager.go | 7 ++-- management/server/peer.go | 8 ++-- management/server/peer_test.go | 4 +- management/server/policy.go | 8 +++- management/server/posture_checks.go | 7 +++- management/server/route.go | 6 +-- .../telemetry/accountmanager_metrics.go | 20 ++++++++++ management/server/types/update_reason.go | 37 +++++++++++++++++++ management/server/user.go | 2 +- 28 files changed, 165 insertions(+), 86 deletions(-) create mode 100644 management/server/types/update_reason.go diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 4b47ecaa0..36de950e9 100644 --- 
a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -257,7 +257,10 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID // UpdatePeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error { +func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { + if c.accountManagerMetrics != nil { + c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation)) + } return c.sendUpdateAccountPeers(ctx, accountID) } @@ -331,9 +334,13 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe return nil } -func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { +func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) + if c.accountManagerMetrics != nil { + c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation)) + } + bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) b := bufUpd.(*bufferUpdate) @@ -348,14 +355,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str go func() { defer b.mu.Unlock() - _ = c.UpdateAccountPeers(ctx, accountID) + _ = c.sendUpdateAccountPeers(ctx, accountID) if !b.update.Load() { return } b.update.Store(false) if b.next == nil { b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { - _ = c.UpdateAccountPeers(ctx, accountID) + _ = c.sendUpdateAccountPeers(ctx, accountID) }) return } diff --git 
a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go index cfea2d3de..44d8f7d72 100644 --- a/management/internals/controllers/network_map/interface.go +++ b/management/internals/controllers/network_map/interface.go @@ -18,9 +18,9 @@ const ( ) type Controller interface { - UpdateAccountPeers(ctx context.Context, accountID string) error + UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error - BufferUpdateAccountPeers(ctx context.Context, accountID string) error + BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) GetDNSDomain(settings *types.Settings) string StartWarmup(context.Context) diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go index 4e86d2973..073a75d3b 100644 --- a/management/internals/controllers/network_map/interface_mock.go +++ b/management/internals/controllers/network_map/interface_mock.go @@ -44,17 +44,17 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder { } // BufferUpdateAccountPeers mocks base method. -func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { +func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID) + ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID, reason) ret0, _ := ret[0].(error) return ret0 } // BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers. 
-func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call { +func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reason any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason) } // CountStreams mocks base method. @@ -238,15 +238,15 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId a } // UpdateAccountPeers mocks base method. -func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error { +func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID) + ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID, reason) ret0, _ := ret[0].(error) return ret0 } // UpdateAccountPeers indicates an expected call of UpdateAccountPeers. 
-func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call { +func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason) } diff --git a/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go index fc3010dd1..314e84501 100644 --- a/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go @@ -62,7 +62,7 @@ func (a *MockAccountManager) GetDeletePeerCalls() int { return a.deletePeerCalls } -func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { +func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { a.mu.Lock() defer a.mu.Unlock() if a.bufferUpdateCalls == nil { @@ -248,7 +248,7 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) { return err } } - mockAM.BufferUpdateAccountPeers(ctx, accountID) + mockAM.BufferUpdateAccountPeers(ctx, accountID, types.UpdateReason{}) return nil }). 
Times(1) diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index d3f8f44ff..c913efb92 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -178,7 +179,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs } } - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete}) return nil } diff --git a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go index 28461641d..fc91b8616 100644 --- a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go +++ b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go @@ -85,7 +85,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor accountMgr := &mock_server.MockAccountManager{ StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) }, diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go 
b/management/internals/modules/reverseproxy/service/manager/manager.go index ed9d4201b..0fb5f46ff 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -25,6 +25,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -231,7 +232,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate}) return s, nil } @@ -515,7 +516,7 @@ func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, s } m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate}) return service, nil } @@ -819,7 +820,7 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete}) return nil } @@ -860,7 +861,7 @@ func (m *Manager) 
DeleteAllServices(ctx context.Context, accountID, userID strin m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster) } - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete}) return nil } @@ -916,7 +917,7 @@ func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate}) return nil } @@ -1098,7 +1099,7 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s } m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate}) serviceURL := "https://" + svc.Domain if service.IsL4Protocol(svc.Mode) { @@ -1210,7 +1211,7 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete}) return nil } @@ -1261,7 +1262,7 @@ func (m *Manager) deleteExpiredPeerService(ctx 
context.Context, accountID, peerI meta := addPeerInfoToEventMeta(svc.EventMeta(), peer) m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) - m.accountManager.UpdateAccountPeers(ctx, accountID) + m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete}) return nil } diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 54ac8ab18..e9403849c 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -447,7 +447,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) { storedActivity = activityID.(activity.Activity) }, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, } mockStore.EXPECT(). @@ -549,7 +549,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) { storedActivity = activityID.(activity.Activity) }, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, } mockStore.EXPECT(). 
@@ -593,7 +593,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, meta map[string]any) { storedMeta = meta }, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, } mockStore.EXPECT(). @@ -704,7 +704,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { accountMgr := &mock_server.MockAccountManager{ StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, - UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, + UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {}, GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) { return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) }, @@ -1173,7 +1173,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) { mockAcct.EXPECT(). StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any()) mockAcct.EXPECT(). 
- UpdateAccountPeers(ctx, accountID) + UpdateAccountPeers(ctx, accountID, gomock.Any()) err = mgr.DeleteService(ctx, accountID, userID, service.ID) require.NoError(t, err) diff --git a/management/internals/modules/zones/manager/manager.go b/management/internals/modules/zones/manager/manager.go index 8548dd48c..439671e65 100644 --- a/management/internals/modules/zones/manager/manager.go +++ b/management/internals/modules/zones/manager/manager.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -144,7 +145,7 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta()) - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationUpdate}) return zone, nil } @@ -206,7 +207,7 @@ func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID event() } - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationDelete}) return nil } diff --git a/management/internals/modules/zones/records/manager/manager.go b/management/internals/modules/zones/records/manager/manager.go index 5374a2ef2..7458b41db 100644 --- a/management/internals/modules/zones/records/manager/manager.go +++ b/management/internals/modules/zones/records/manager/manager.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" 
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -95,7 +96,7 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI meta := record.EventMeta(zone.ID, zone.Name) m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta) - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationCreate}) return record, nil } @@ -154,7 +155,7 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI meta := record.EventMeta(zone.ID, zone.Name) m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta) - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationUpdate}) return record, nil } @@ -201,7 +202,7 @@ func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneI meta := record.EventMeta(zone.ID, zone.Name) m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta) - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationDelete}) return nil } diff --git a/management/server/account.go b/management/server/account.go index 7d53cef03..4b71ab486 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -400,7 +400,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if updateAccountPeers || extraSettingsChanged || 
groupChangesAffectPeers { - go am.UpdateAccountPeers(ctx, accountID) + go am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceAccountSettings, Operation: types.UpdateOperationUpdate}) } return newSettings, nil @@ -1581,7 +1581,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) - am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) + am.BufferUpdateAccountPeers(ctx, userAuth.AccountId, types.UpdateReason{Resource: types.UpdateResourceUser, Operation: types.UpdateOperationUpdate}) } return nil diff --git a/management/server/account/manager.go b/management/server/account/manager.go index b4516d512..626ed222d 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -124,8 +124,8 @@ type Manager interface { GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error - UpdateAccountPeers(ctx context.Context, accountID string) - BufferUpdateAccountPeers(ctx context.Context, accountID string) + UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) + BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error GetStore() store.Store diff --git a/management/server/account/manager_mock.go b/management/server/account/manager_mock.go index 36e5fe39f..8f3b22ecc 100644 --- a/management/server/account/manager_mock.go +++ 
b/management/server/account/manager_mock.go @@ -111,15 +111,15 @@ func (mr *MockManagerMockRecorder) ApproveUser(ctx, accountID, initiatorUserID, } // BufferUpdateAccountPeers mocks base method. -func (m *MockManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { +func (m *MockManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { m.ctrl.T.Helper() - m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID) + m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID, reason) } // BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers. -func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reason interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID, reason) } // BuildUserInfosForAccount mocks base method. @@ -1597,15 +1597,15 @@ func (mr *MockManagerMockRecorder) UpdateAccountOnboarding(ctx, accountID, userI } // UpdateAccountPeers mocks base method. -func (m *MockManager) UpdateAccountPeers(ctx context.Context, accountID string) { +func (m *MockManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID) + m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID, reason) } // UpdateAccountPeers indicates an expected call of UpdateAccountPeers. 
-func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason) } // UpdateAccountSettings mocks base method. diff --git a/management/server/dns.go b/management/server/dns.go index baf6debc3..c62fa5185 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -86,7 +86,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate}) } return nil diff --git a/management/server/group.go b/management/server/group.go index 7b5b9b86c..e1d05171e 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -117,7 +117,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate}) } return nil @@ -185,7 +185,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -253,7 +253,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } if updateAccountPeers { - 
am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate}) } return globalErr @@ -321,7 +321,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return globalErr @@ -493,7 +493,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -531,7 +531,7 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -559,7 +559,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -597,7 +597,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun } if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate}) } return nil diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index ff369355e..ac4d0c6d6 100644 --- a/management/server/mock_server/account_mock.go +++ 
b/management/server/mock_server/account_mock.go @@ -128,8 +128,8 @@ type MockAccountManager struct { GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) AllowSyncFunc func(string, uint64) bool - UpdateAccountPeersFunc func(ctx context.Context, accountID string) - BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + UpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason) RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error GetIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) @@ -200,15 +200,15 @@ func (am *MockAccountManager) UpdateGroups(ctx context.Context, accountID, userI return status.Errorf(codes.Unimplemented, "method UpdateGroups is not implemented") } -func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { +func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { if am.UpdateAccountPeersFunc != nil { - am.UpdateAccountPeersFunc(ctx, accountID) + am.UpdateAccountPeersFunc(ctx, accountID, reason) } } -func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { +func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { if am.BufferUpdateAccountPeersFunc != nil { - am.BufferUpdateAccountPeersFunc(ctx, accountID) + am.BufferUpdateAccountPeersFunc(ctx, accountID, reason) } } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 3d8c78912..5859bfb0d 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -82,7 +82,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, 
acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationCreate}) } return newNSGroup.Copy(), nil @@ -133,7 +133,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationUpdate}) } return nil @@ -176,7 +176,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationDelete}) } return nil diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index b6706ca45..c96b60bb2 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + serverTypes "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -177,7 +178,7 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw event() } - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go 
m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetwork, Operation: serverTypes.UpdateOperationDelete}) return nil } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 86f9b6579..5a0e26533 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -162,7 +162,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc event() } - go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationCreate}) return resource, nil } @@ -270,7 +270,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc } }() - go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationUpdate}) return resource, nil } @@ -352,7 +352,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net event() } - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationDelete}) return nil } diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 82cac424a..c7c3f2ff4 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" 
"github.com/netbirdio/netbird/management/server/store" + serverTypes "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -119,7 +120,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) - go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationCreate}) return router, nil } @@ -183,7 +184,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) - go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationUpdate}) return router, nil } @@ -217,7 +218,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo event() - go m.accountManager.UpdateAccountPeers(ctx, accountID) + go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationDelete}) return nil } diff --git a/management/server/peer.go b/management/server/peer.go index 07428539b..d1c52002e 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1221,12 +1221,12 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, // UpdateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. 
-func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { - _ = am.networkMapController.UpdateAccountPeers(ctx, accountID) +func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { + _ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason) } -func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID) +func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) { + _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason) } // UpdateAccountPeer updates a single peer that belongs to an account. diff --git a/management/server/peer_test.go b/management/server/peer_test.go index dae676e77..36809d354 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -975,7 +975,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.UpdateAccountPeers(ctx, account.Id) + manager.UpdateAccountPeers(ctx, account.Id, types.UpdateReason{}) } duration := time.Since(start) @@ -1033,7 +1033,7 @@ func testUpdateAccountPeers(t *testing.T) { peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID) } - manager.UpdateAccountPeers(ctx, account.Id) + manager.UpdateAccountPeers(ctx, account.Id, types.UpdateReason{}) for _, channel := range peerChannels { update := <-channel diff --git a/management/server/policy.go b/management/server/policy.go index 48297ca11..40f3908e3 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -96,7 +96,11 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + policyOp := 
types.UpdateOperationCreate + if isUpdate { + policyOp = types.UpdateOperationUpdate + } + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: policyOp}) } return policy, nil @@ -139,7 +143,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: types.UpdateOperationDelete}) } return nil diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 9562487c0..1e3ce4b8a 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -76,7 +77,11 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + postureOp := types.UpdateOperationCreate + if isUpdate { + postureOp = types.UpdateOperationUpdate + } + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePostureCheck, Operation: postureOp}) } return postureChecks, nil diff --git a/management/server/route.go b/management/server/route.go index 2b4f11d05..a9561faf0 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -191,7 +191,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), 
accountID, activity.RouteCreated, newRoute.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationCreate}) } return newRoute, nil @@ -245,7 +245,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) if oldRouteAffectsPeers || newRouteAffectsPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationUpdate}) } return nil @@ -288,7 +288,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationDelete}) } return nil diff --git a/management/server/telemetry/accountmanager_metrics.go b/management/server/telemetry/accountmanager_metrics.go index 3b1e078eb..518aae7eb 100644 --- a/management/server/telemetry/accountmanager_metrics.go +++ b/management/server/telemetry/accountmanager_metrics.go @@ -4,6 +4,7 @@ import ( "context" "time" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) @@ -11,6 +12,7 @@ import ( type AccountManagerMetrics struct { ctx context.Context updateAccountPeersDurationMs metric.Float64Histogram + updateAccountPeersCounter metric.Int64Counter getPeerNetworkMapDurationMs metric.Float64Histogram networkMapObjectCount metric.Int64Histogram peerMetaUpdateCount metric.Int64Counter @@ -48,6 +50,13 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account return nil, err } + 
updateAccountPeersCounter, err := meter.Int64Counter("management.account.update.account.peers.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of account peers updates triggered, labeled by resource and operation")) + if err != nil { + return nil, err + } + peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1"), metric.WithDescription("Number of updates with new meta data from the peers")) @@ -59,6 +68,7 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account ctx: ctx, getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, updateAccountPeersDurationMs: updateAccountPeersDurationMs, + updateAccountPeersCounter: updateAccountPeersCounter, networkMapObjectCount: networkMapObjectCount, peerMetaUpdateCount: peerMetaUpdateCount, }, nil @@ -80,6 +90,16 @@ func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) { metrics.networkMapObjectCount.Record(metrics.ctx, count) } +// CountUpdateAccountPeersTriggered increments the counter for account peers updates with resource and operation labels. +func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource, operation string) { + metrics.updateAccountPeersCounter.Add(metrics.ctx, 1, + metric.WithAttributes( + attribute.String("resource", resource), + attribute.String("operation", operation), + ), + ) +} + // CountPeerMetUpdate counts the number of peer meta updates func (metrics *AccountManagerMetrics) CountPeerMetUpdate() { metrics.peerMetaUpdateCount.Add(metrics.ctx, 1) diff --git a/management/server/types/update_reason.go b/management/server/types/update_reason.go new file mode 100644 index 000000000..9d752da9a --- /dev/null +++ b/management/server/types/update_reason.go @@ -0,0 +1,37 @@ +package types + +// UpdateReason describes why an account peers update was triggered. 
+type UpdateReason struct { + Resource UpdateResource + Operation UpdateOperation +} + +// UpdateResource represents the kind of resource that triggered an account peers update. +type UpdateResource string + +const ( + UpdateResourceAccountSettings UpdateResource = "account_settings" + UpdateResourceDNSSettings UpdateResource = "dns_settings" + UpdateResourceGroup UpdateResource = "group" + UpdateResourceNameServerGroup UpdateResource = "nameserver_group" + UpdateResourcePolicy UpdateResource = "policy" + UpdateResourcePostureCheck UpdateResource = "posture_check" + UpdateResourceRoute UpdateResource = "route" + UpdateResourceUser UpdateResource = "user" + UpdateResourcePeer UpdateResource = "peer" + UpdateResourceNetwork UpdateResource = "network" + UpdateResourceNetworkResource UpdateResource = "network_resource" + UpdateResourceNetworkRouter UpdateResource = "network_router" + UpdateResourceService UpdateResource = "service" + UpdateResourceZone UpdateResource = "zone" + UpdateResourceZoneRecord UpdateResource = "zone_record" +) + +// UpdateOperation represents the kind of change that triggered the update. 
+type UpdateOperation string + +const ( + UpdateOperationCreate UpdateOperation = "create" + UpdateOperationUpdate UpdateOperation = "update" + UpdateOperationDelete UpdateOperation = "delete" +) diff --git a/management/server/user.go b/management/server/user.go index c1f984f2f..b1fb51195 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -675,7 +675,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil { return nil, fmt.Errorf("failed to increment network serial: %w", err) } - am.UpdateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceUser, Operation: types.UpdateOperationUpdate}) } return updatedUsersInfo, globalErr From dcd1db42ef212d3c8e14a0be451681460fce0a7c Mon Sep 17 00:00:00 2001 From: Nicolas Frati Date: Thu, 30 Apr 2026 17:21:35 +0200 Subject: [PATCH 10/12] [management] Enable PAT creation during setup (#6003) * enable pat creation on setup * remove logic from handler towards setup service * fix lint issue * fix rollback on account id returning empty * fix coderabbit comments * fix setup PAT rollback behavior --- management/server/account/pat.go | 8 + management/server/http/handler.go | 6 +- .../handlers/instance/instance_handler.go | 31 +- .../instance/instance_handler_test.go | 254 +++++++++++++- management/server/instance/manager.go | 54 +++ management/server/instance/manager_test.go | 87 ++++- management/server/instance/setup_service.go | 216 ++++++++++++ .../server/instance/setup_service_test.go | 318 ++++++++++++++++++ management/server/user.go | 5 +- shared/management/http/api/openapi.yml | 28 +- shared/management/http/api/types.gen.go | 9 + 11 files changed, 997 insertions(+), 19 deletions(-) create mode 100644 management/server/account/pat.go create mode 100644 management/server/instance/setup_service.go create mode 100644 
management/server/instance/setup_service_test.go diff --git a/management/server/account/pat.go b/management/server/account/pat.go new file mode 100644 index 000000000..8e5e3e3f9 --- /dev/null +++ b/management/server/account/pat.go @@ -0,0 +1,8 @@ +package account + +const ( + // PATMinExpireDays is the minimum allowed Personal Access Token expiration in days. + PATMinExpireDays = 1 + // PATMaxExpireDays is the maximum allowed Personal Access Token expiration in days. + PATMaxExpireDays = 365 +) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 56b2d8203..b9ea605d3 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -62,9 +62,7 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" ) -const ( - apiPrefix = "/api" -) +const apiPrefix = "/api" // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) { @@ -141,7 +139,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks 
zonesManager.RegisterEndpoints(router, zManager) recordsManager.RegisterEndpoints(router, rManager) idp.AddEndpoints(accountManager, router) - instance.AddEndpoints(instanceManager, router) + instance.AddEndpoints(instanceManager, accountManager, router) instance.AddVersionEndpoint(instanceManager, router) if serviceManager != nil && reverseProxyDomainManager != nil { reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) diff --git a/management/server/http/handlers/instance/instance_handler.go b/management/server/http/handlers/instance/instance_handler.go index cd9fae6b8..e98ce4d7c 100644 --- a/management/server/http/handlers/instance/instance_handler.go +++ b/management/server/http/handlers/instance/instance_handler.go @@ -7,6 +7,7 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/account" nbinstance "github.com/netbirdio/netbird/management/server/instance" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" @@ -15,13 +16,15 @@ import ( // handler handles the instance setup HTTP endpoints type handler struct { instanceManager nbinstance.Manager + setupManager *nbinstance.SetupService } // AddEndpoints registers the instance setup endpoints. // These endpoints bypass authentication for initial setup. -func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) { +func AddEndpoints(instanceManager nbinstance.Manager, accountManager account.Manager, router *mux.Router) { h := &handler{ instanceManager: instanceManager, + setupManager: nbinstance.NewSetupService(instanceManager, accountManager), } router.HandleFunc("/instance", h.getInstanceStatus).Methods("GET", "OPTIONS") @@ -55,24 +58,36 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) { // setup creates the initial admin user for the instance. 
// This endpoint is unauthenticated but only works when setup is required. func (h *handler) setup(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + var req api.SetupRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w) return } - userData, err := h.instanceManager.CreateOwnerUser(r.Context(), req.Email, req.Password, req.Name) + result, err := h.setupManager.SetupOwner(ctx, req.Email, req.Password, req.Name, nbinstance.SetupOptions{ + CreatePAT: req.CreatePat != nil && *req.CreatePat, + PATExpireInDays: req.PatExpireIn, + }) if err != nil { - util.WriteError(r.Context(), err, w) + util.WriteError(ctx, err, w) return } - log.WithContext(r.Context()).Infof("instance setup completed: created user %s", req.Email) + log.WithContext(ctx).Infof("instance setup completed: created user %s", req.Email) - util.WriteJSONObject(r.Context(), w, api.SetupResponse{ - UserId: userData.ID, - Email: userData.Email, - }) + resp := api.SetupResponse{ + UserId: result.User.ID, + Email: result.User.Email, + } + + if result.PATPlainToken != "" { + resp.PersonalAccessToken = &result.PATPlainToken + } + + w.Header().Set("Cache-Control", "no-store") + util.WriteJSONObject(ctx, w, resp) } // getVersionInfo returns version information for NetBird components. 
diff --git a/management/server/http/handlers/instance/instance_handler_test.go b/management/server/http/handlers/instance/instance_handler_test.go index 470079c85..711e01964 100644 --- a/management/server/http/handlers/instance/instance_handler_test.go +++ b/management/server/http/handlers/instance/instance_handler_test.go @@ -10,12 +10,18 @@ import ( "net/mail" "testing" + "github.com/golang/mock/gomock" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/idp" nbinstance "github.com/netbirdio/netbird/management/server/instance" + "github.com/netbirdio/netbird/management/server/mock_server" + nbstore "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -25,6 +31,7 @@ type mockInstanceManager struct { isSetupRequired bool isSetupRequiredFn func(ctx context.Context) (bool, error) createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error) + rollbackSetupFn func(ctx context.Context, userID string) error getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error) } @@ -67,6 +74,13 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo }, nil } +func (m *mockInstanceManager) RollbackSetup(ctx context.Context, userID string) error { + if m.rollbackSetupFn != nil { + return m.rollbackSetupFn(ctx, userID) + } + return nil +} + func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) { if m.getVersionInfoFn != nil { return m.getVersionInfoFn(ctx) @@ -82,8 +96,12 @@ func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.V var _ nbinstance.Manager = 
(*mockInstanceManager)(nil) func setupTestRouter(manager nbinstance.Manager) *mux.Router { + return setupTestRouterWithPAT(manager, nil) +} + +func setupTestRouterWithPAT(manager nbinstance.Manager, accountManager account.Manager) *mux.Router { router := mux.NewRouter() - AddEndpoints(manager, router) + AddEndpoints(manager, accountManager, router) return router } @@ -161,6 +179,7 @@ func TestSetup_Success(t *testing.T) { router.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "no-store", rec.Header().Get("Cache-Control")) var response api.SetupResponse err := json.NewDecoder(rec.Body).Decode(&response) @@ -293,6 +312,239 @@ func TestSetup_ManagerError(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, rec.Code) } +func TestSetup_PAT_FeatureDisabled_IgnoresCreatePAT(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "false") + + manager := &mockInstanceManager{isSetupRequired: true} + // NB_SETUP_PAT_ENABLED=false: request fields must be silently ignored + router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var response api.SetupResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) + assert.Nil(t, response.PersonalAccessToken) +} + +func TestSetup_PAT_FlagOmitted_NoPAT(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + manager := &mockInstanceManager{isSetupRequired: true} + router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin"}` + req := httptest.NewRequest(http.MethodPost, 
"/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + var response api.SetupResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) + assert.Nil(t, response.PersonalAccessToken) +} + +func TestSetup_PAT_MissingExpireIn_DefaultsToOneDay(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + createCalled := false + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + createCalled = true + return &idp.UserData{ID: "u1", Email: email, Name: name}, nil + }, + } + accountMgr := &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + assert.Equal(t, "u1", userAuth.UserId) + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + assert.Equal(t, "acc-1", accountID) + assert.Equal(t, "u1", initiator) + assert.Equal(t, "u1", target) + assert.Equal(t, "setup-token", name) + assert.Equal(t, 1, expiresIn) + return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil + }, + } + router := setupTestRouterWithPAT(manager, accountMgr) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "no-store", rec.Header().Get("Cache-Control")) + assert.True(t, createCalled) + var response api.SetupResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) + 
require.NotNil(t, response.PersonalAccessToken) + assert.Equal(t, "nbp_plain", *response.PersonalAccessToken) +} + +func TestSetup_PAT_ExpireOutOfRange(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + manager := &mockInstanceManager{isSetupRequired: true} + router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 0}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnprocessableEntity, rec.Code) +} + +func TestSetup_PAT_Success(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + } + + gotAccountArgs := struct { + userID string + email string + }{} + accountMgr := &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + gotAccountArgs.userID = userAuth.UserId + gotAccountArgs.email = userAuth.Email + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + assert.Equal(t, "acc-1", accountID) + assert.Equal(t, "owner-id", initiator) + assert.Equal(t, "owner-id", target) + assert.Equal(t, "setup-token", name) + assert.Equal(t, 30, expiresIn) + return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil + }, + } + + router := setupTestRouterWithPAT(manager, accountMgr) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", 
"create_pat": true, "pat_expire_in": 30}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "no-store", rec.Header().Get("Cache-Control")) + var response api.SetupResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&response)) + assert.Equal(t, "owner-id", response.UserId) + require.NotNil(t, response.PersonalAccessToken) + assert.Equal(t, "nbp_plain", *response.PersonalAccessToken) + assert.Equal(t, "owner-id", gotAccountArgs.userID) + assert.Equal(t, "admin@example.com", gotAccountArgs.email) +} + +func TestSetup_PAT_AccountCreationFails_Rollback(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + accountStore.EXPECT().GetAccountIDByUserID(gomock.Any(), nbstore.LockingStrengthNone, "owner-id").Return("", status.NewAccountNotFoundError("owner-id")) + + rolledBackFor := "" + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + rollbackSetupFn: func(_ context.Context, userID string) error { + rolledBackFor = userID + return nil + }, + } + accountMgr := &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) { + return "", errors.New("db down") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + } + + router := setupTestRouterWithPAT(manager, accountMgr) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + 
req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called with the created user id") +} + +func TestSetup_PAT_CreatePATFails_Rollback(t *testing.T) { + t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + account := &types.Account{Id: "acc-1"} + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil) + accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil) + + rolledBackFor := "" + manager := &mockInstanceManager{ + isSetupRequired: true, + createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) { + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + rollbackSetupFn: func(_ context.Context, userID string) error { + rolledBackFor = userID + return nil + }, + } + accountMgr := &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) { + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) { + return nil, status.Errorf(status.Internal, "token store unavailable") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + } + + router := setupTestRouterWithPAT(manager, accountMgr) + + body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}` + req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called 
when CreatePAT fails") +} + func TestGetVersionInfo_Success(t *testing.T) { manager := &mockInstanceManager{} router := mux.NewRouter() diff --git a/management/server/instance/manager.go b/management/server/instance/manager.go index 9579d7a35..2c355bb3b 100644 --- a/management/server/instance/manager.go +++ b/management/server/instance/manager.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/dexidp/dex/storage" goversion "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" @@ -60,6 +61,13 @@ type Manager interface { // This should only be called when IsSetupRequired returns true. CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) + // RollbackSetup reverses a successful CreateOwnerUser by deleting the user + // from the embedded IDP and reloading setupRequired from persistent state, so + // /api/setup can be retried only when no accounts or local users remain. Used + // when post-user steps (account or PAT creation) fail and the caller wants a + // clean slate. + RollbackSetup(ctx context.Context, userID string) error + // GetVersionInfo returns version information for NetBird components. GetVersionInfo(ctx context.Context) (*VersionInfo, error) } @@ -70,6 +78,7 @@ type instanceStore interface { type embeddedIdP interface { CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) + DeleteUser(ctx context.Context, userID string) error GetAllAccounts(ctx context.Context) (map[string][]*idp.UserData, error) } @@ -187,6 +196,51 @@ func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, n return userData, nil } +// RollbackSetup undoes a successful CreateOwnerUser: deletes the user from the +// embedded IDP and reloads setupRequired from persistent state. 
+func (m *DefaultManager) RollbackSetup(ctx context.Context, userID string) error { + if m.embeddedIdpManager == nil { + return errors.New("embedded IDP is not enabled") + } + + var deleteErr error + if err := m.embeddedIdpManager.DeleteUser(ctx, userID); err != nil { + if isNotFoundError(err) { + log.WithContext(ctx).Debugf("setup rollback user %s already deleted", userID) + } else { + deleteErr = fmt.Errorf("failed to delete user from embedded IdP: %w", err) + } + } + + if err := m.loadSetupRequired(ctx); err != nil { + reloadErr := fmt.Errorf("failed to reload setup state after rollback: %w", err) + if deleteErr != nil { + return errors.Join(deleteErr, reloadErr) + } + return reloadErr + } + + if deleteErr != nil { + return deleteErr + } + + log.WithContext(ctx).Infof("rolled back setup for user %s", userID) + return nil +} + +func isNotFoundError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, storage.ErrNotFound) { + return true + } + if s, ok := status.FromError(err); ok { + return s.Type() == status.NotFound + } + return false +} + func (m *DefaultManager) checkSetupRequiredFromDB(ctx context.Context) error { numAccounts, err := m.store.GetAccountsCounter(ctx) if err != nil { diff --git a/management/server/instance/manager_test.go b/management/server/instance/manager_test.go index e3be9cfea..5ffb016de 100644 --- a/management/server/instance/manager_test.go +++ b/management/server/instance/manager_test.go @@ -10,16 +10,19 @@ import ( "testing" "time" + "github.com/dexidp/dex/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/shared/management/status" ) type mockIdP struct { - mu sync.Mutex - createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error) - users map[string][]*idp.UserData + mu sync.Mutex + createUserFunc func(ctx context.Context, email, password, name string) 
(*idp.UserData, error) + deleteUserFunc func(ctx context.Context, userID string) error + users map[string][]*idp.UserData getAllAccountsErr error } @@ -30,6 +33,13 @@ func (m *mockIdP) CreateUserWithPassword(ctx context.Context, email, password, n return &idp.UserData{ID: "test-user-id", Email: email, Name: name}, nil } +func (m *mockIdP) DeleteUser(ctx context.Context, userID string) error { + if m.deleteUserFunc != nil { + return m.deleteUserFunc(ctx, userID) + } + return nil +} + func (m *mockIdP) GetAllAccounts(_ context.Context) (map[string][]*idp.UserData, error) { m.mu.Lock() defer m.mu.Unlock() @@ -223,6 +233,77 @@ func TestIsSetupRequired_ReturnsFlag(t *testing.T) { assert.False(t, required) } +func TestRollbackSetup_UserAlreadyDeletedIsSuccess(t *testing.T) { + tests := []struct { + name string + err error + }{ + { + name: "management status not found", + err: status.NewUserNotFoundError("owner-id"), + }, + { + name: "dex storage not found", + err: fmt.Errorf("failed to get user for deletion: %w", storage.ErrNotFound), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + idpMock := &mockIdP{ + deleteUserFunc: func(_ context.Context, userID string) error { + assert.Equal(t, "owner-id", userID) + return tt.err + }, + } + mgr := newTestManager(idpMock, &mockStore{}) + mgr.setupRequired = false + + err := mgr.RollbackSetup(context.Background(), "owner-id") + require.NoError(t, err) + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.True(t, required, "setup should be required when no accounts or local users remain") + }) + } +} + +func TestRollbackSetup_RecomputesSetupStateWhenAccountStillExists(t *testing.T) { + idpMock := &mockIdP{ + deleteUserFunc: func(_ context.Context, _ string) error { + return status.NewUserNotFoundError("owner-id") + }, + } + mgr := newTestManager(idpMock, &mockStore{accountsCount: 1}) + mgr.setupRequired = true + + err := mgr.RollbackSetup(context.Background(), 
"owner-id") + require.NoError(t, err) + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.False(t, required, "setup should not be required while an account still exists") +} + +func TestRollbackSetup_ReturnsDeleteErrorButReloadsSetupState(t *testing.T) { + idpMock := &mockIdP{ + deleteUserFunc: func(_ context.Context, _ string) error { + return errors.New("idp unavailable") + }, + } + mgr := newTestManager(idpMock, &mockStore{}) + mgr.setupRequired = false + + err := mgr.RollbackSetup(context.Background(), "owner-id") + require.Error(t, err) + assert.Contains(t, err.Error(), "idp unavailable") + + required, err := mgr.IsSetupRequired(context.Background()) + require.NoError(t, err) + assert.True(t, required, "setup state should be reloaded even when user deletion fails") +} + func TestDefaultManager_ValidateSetupRequest(t *testing.T) { manager := &DefaultManager{setupRequired: true} diff --git a/management/server/instance/setup_service.go b/management/server/instance/setup_service.go new file mode 100644 index 000000000..92a4923be --- /dev/null +++ b/management/server/instance/setup_service.go @@ -0,0 +1,216 @@ +package instance + +import ( + "context" + "fmt" + "os" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/status" +) + +const ( + setupPATTokenName = "setup-token" + + // SetupPATEnabledEnvKey enables setup-time Personal Access Token creation. + SetupPATEnabledEnvKey = "NB_SETUP_PAT_ENABLED" + + setupPATDefaultExpireDays = 1 +) + +// SetupOptions controls optional work performed during initial instance setup. +type SetupOptions struct { + // CreatePAT requests creation of a setup Personal Access Token. 
It is honored + // only when SetupPATEnabledEnvKey is set to "true". + CreatePAT bool + // PATExpireInDays defaults to 1 day when CreatePAT is requested and setup PAT + // creation is enabled. + PATExpireInDays *int +} + +// SetupResult contains resources created during initial instance setup. +type SetupResult struct { + User *idp.UserData + PATPlainToken string +} + +// SetupService orchestrates the initial setup use case across the instance and +// account bounded contexts and owns the compensation logic when a later step +// fails. +type SetupService struct { + instanceManager Manager + accountManager account.Manager + setupPATEnabled bool +} + +// NewSetupService creates a setup use-case service. +func NewSetupService(instanceManager Manager, accountManager account.Manager) *SetupService { + return &SetupService{ + instanceManager: instanceManager, + accountManager: accountManager, + setupPATEnabled: os.Getenv(SetupPATEnabledEnvKey) == "true", + } +} + +func normalizeSetupOptions(opts SetupOptions, setupPATEnabled bool) (SetupOptions, error) { + if !opts.CreatePAT { + return opts, nil + } + + if !setupPATEnabled { + opts.CreatePAT = false + opts.PATExpireInDays = nil + return opts, nil + } + + if opts.PATExpireInDays == nil { + defaultExpireInDays := setupPATDefaultExpireDays + opts.PATExpireInDays = &defaultExpireInDays + } + + if *opts.PATExpireInDays < account.PATMinExpireDays || *opts.PATExpireInDays > account.PATMaxExpireDays { + return opts, status.Errorf(status.InvalidArgument, "pat_expire_in must be between %d and %d", account.PATMinExpireDays, account.PATMaxExpireDays) + } + + return opts, nil +} + +// SetupOwner creates the initial owner user and, when requested and enabled by +// SetupPATEnabledEnvKey, provisions the account and a setup Personal Access +// Token. If account or PAT provisioning fails, created resources are rolled +// back so setup can be retried. 
If account rollback fails, user rollback is +// skipped to avoid leaving an account without its owner user. +func (m *SetupService) SetupOwner(ctx context.Context, email, password, name string, opts SetupOptions) (*SetupResult, error) { + opts, err := normalizeSetupOptions(opts, m.setupPATEnabled) + if err != nil { + return nil, err + } + + if opts.CreatePAT && m.accountManager == nil { + return nil, fmt.Errorf("account manager is required to create setup PAT") + } + + userData, err := m.instanceManager.CreateOwnerUser(ctx, email, password, name) + if err != nil { + return nil, err + } + + result := &SetupResult{User: userData} + if !opts.CreatePAT { + return result, nil + } + + userAuth := auth.UserAuth{ + UserId: userData.ID, + Email: userData.Email, + Name: userData.Name, + } + + accountID, err := m.accountManager.GetAccountIDByUserID(ctx, userAuth) + if err != nil { + err = fmt.Errorf("create account for setup user: %w", err) + if rollbackErr := m.rollbackSetup(ctx, userData.ID, "account provisioning failed", err, ""); rollbackErr != nil { + return nil, fmt.Errorf("%w; failed to roll back setup resources: %v", err, rollbackErr) + } + return nil, err + } + + pat, err := m.accountManager.CreatePAT(ctx, accountID, userData.ID, userData.ID, setupPATTokenName, *opts.PATExpireInDays) + if err != nil { + err = fmt.Errorf("create setup PAT: %w", err) + if rollbackErr := m.rollbackSetup(ctx, userData.ID, "setup PAT provisioning failed", err, accountID); rollbackErr != nil { + return nil, fmt.Errorf("%w; failed to roll back setup resources: %v", err, rollbackErr) + } + return nil, err + } + + result.PATPlainToken = pat.PlainToken + return result, nil +} + +func (m *SetupService) rollbackSetup(ctx context.Context, userID, reason string, origErr error, accountID string) error { + if accountID == "" { + resolvedAccountID, err := m.lookupSetupAccountIDForRollback(ctx, userID) + if err != nil { + rollbackErr := fmt.Errorf("resolve setup account for rollback: %w", err) + 
log.WithContext(ctx).Errorf("failed to resolve setup account for user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, rollbackErr) + return rollbackErr + } + accountID = resolvedAccountID + } + + if accountID != "" { + if err := m.rollbackSetupAccount(ctx, accountID); err != nil { + rollbackErr := fmt.Errorf("roll back setup account %s: %w", accountID, err) + log.WithContext(ctx).Errorf("failed to roll back setup account %s for user %s after %s: original error: %v, rollback error: %v", accountID, userID, reason, origErr, rollbackErr) + return rollbackErr + } + log.WithContext(ctx).Warnf("rolled back setup account %s for user %s after %s: %v", accountID, userID, reason, origErr) + } + + if err := m.instanceManager.RollbackSetup(ctx, userID); err != nil { + rollbackErr := fmt.Errorf("roll back setup user %s: %w", userID, err) + log.WithContext(ctx).Errorf("failed to roll back setup user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, rollbackErr) + return rollbackErr + } + log.WithContext(ctx).Warnf("rolled back setup user %s after %s: %v", userID, reason, origErr) + return nil +} + +func (m *SetupService) lookupSetupAccountIDForRollback(ctx context.Context, userID string) (string, error) { + if m.accountManager == nil { + return "", fmt.Errorf("account manager is required to resolve setup account") + } + + accountStore := m.accountManager.GetStore() + if accountStore == nil { + return "", fmt.Errorf("account store is unavailable") + } + + accountID, err := accountStore.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID) + if err != nil { + if isNotFoundError(err) { + return "", nil + } + return "", fmt.Errorf("get setup account ID for rollback: %w", err) + } + + return accountID, nil +} + +// rollbackSetupAccount removes only the setup-created account data from the +// store. 
It intentionally avoids accountManager.DeleteAccount because the normal +// account deletion path also deletes users from the IdP; embedded IdP cleanup is +// owned by instanceManager.RollbackSetup. +func (m *SetupService) rollbackSetupAccount(ctx context.Context, accountID string) error { + if m.accountManager == nil { + return fmt.Errorf("account manager is required to roll back setup account") + } + + accountStore := m.accountManager.GetStore() + if accountStore == nil { + return fmt.Errorf("account store is unavailable") + } + + account, err := accountStore.GetAccount(ctx, accountID) + if err != nil { + if isNotFoundError(err) { + return nil + } + return fmt.Errorf("get setup account for rollback: %w", err) + } + + if err := accountStore.DeleteAccount(ctx, account); err != nil { + if isNotFoundError(err) { + return nil + } + return fmt.Errorf("delete setup account for rollback: %w", err) + } + + return nil +} diff --git a/management/server/instance/setup_service_test.go b/management/server/instance/setup_service_test.go new file mode 100644 index 000000000..12ec7d0fa --- /dev/null +++ b/management/server/instance/setup_service_test.go @@ -0,0 +1,318 @@ +package instance + +import ( + "context" + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/mock_server" + nbstore "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/status" +) + +type setupInstanceManagerMock struct { + createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error) + rollbackSetupFn func(ctx context.Context, userID string) error +} + +func (m *setupInstanceManagerMock) IsSetupRequired(context.Context) (bool, error) { 
+ return true, nil +} + +func (m *setupInstanceManagerMock) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) { + if m.createOwnerUserFn != nil { + return m.createOwnerUserFn(ctx, email, password, name) + } + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil +} + +func (m *setupInstanceManagerMock) RollbackSetup(ctx context.Context, userID string) error { + if m.rollbackSetupFn != nil { + return m.rollbackSetupFn(ctx, userID) + } + return nil +} + +func (m *setupInstanceManagerMock) GetVersionInfo(context.Context) (*VersionInfo, error) { + return &VersionInfo{}, nil +} + +var _ Manager = (*setupInstanceManagerMock)(nil) + +func intPtr(v int) *int { + return &v +} + +func TestSetupOwner_PATFeatureDisabled_IgnoresCreatePAT(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "false") + + createCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) { + createCalls++ + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + }, + &mock_server.MockAccountManager{}, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "owner-id", result.User.ID) + assert.Empty(t, result.PATPlainToken) + assert.Equal(t, 1, createCalls) +} + +func TestSetupOwner_PATFeatureEnabled_MissingExpireDefaultsToOneDay(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + createCalled := false + setupManager := NewSetupService( + &setupInstanceManagerMock{ + createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) { + createCalled = true + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, 
userAuth auth.UserAuth) (string, error) { + assert.Equal(t, "owner-id", userAuth.UserId) + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + assert.Equal(t, "acc-1", accountID) + assert.Equal(t, "owner-id", initiatorUserID) + assert.Equal(t, "owner-id", targetUserID) + assert.Equal(t, setupPATTokenName, tokenName) + assert.Equal(t, setupPATDefaultExpireDays, expiresIn) + return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, createCalled) + assert.Equal(t, "nbp_plain", result.PATPlainToken) +} + +func TestSetupOwner_PATFeatureEnabled_MissingAccountManagerFailsBeforeCreateUser(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + createCalled := false + rollbackCalled := false + setupManager := NewSetupService( + &setupInstanceManagerMock{ + createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) { + createCalled = true + return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil + }, + rollbackSetupFn: func(_ context.Context, _ string) error { + rollbackCalled = true + return nil + }, + }, + nil, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "account manager is required") + assert.False(t, createCalled) + assert.False(t, rollbackCalled) +} + +func TestSetupOwner_AccountProvisioningFails_RollsBackSideEffectAccountAndUser(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := 
nbstore.NewMockStore(ctrl) + account := &types.Account{Id: "acc-1"} + accountStore.EXPECT().GetAccountIDByUserID(gomock.Any(), nbstore.LockingStrengthNone, "owner-id").Return("acc-1", nil) + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil) + accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil) + + rolledBackFor := "" + rollbackCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + rollbackSetupFn: func(_ context.Context, userID string) error { + rollbackCalls++ + rolledBackFor = userID + return nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) { + assert.Equal(t, "owner-id", userAuth.UserId) + return "", errors.New("metadata update failed") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + PATExpireInDays: intPtr(30), + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "create account for setup user") + assert.Equal(t, "owner-id", rolledBackFor) + assert.Equal(t, 1, rollbackCalls) +} + +func TestSetupOwner_CreatePATFails_RollsBackSetupAccountAndUser(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + account := &types.Account{Id: "acc-1"} + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil) + accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil) + + rollbackCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + rollbackSetupFn: func(_ context.Context, userID string) error { + rollbackCalls++ + assert.Equal(t, "owner-id", userID) + return nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, 
userAuth auth.UserAuth) (string, error) { + assert.Equal(t, "owner-id", userAuth.UserId) + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + assert.Equal(t, "acc-1", accountID) + assert.Equal(t, "owner-id", initiatorUserID) + assert.Equal(t, "owner-id", targetUserID) + assert.Equal(t, setupPATTokenName, tokenName) + assert.Equal(t, 30, expiresIn) + return nil, status.Errorf(status.Internal, "token store unavailable") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + PATExpireInDays: intPtr(30), + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "create setup PAT") + assert.Equal(t, 1, rollbackCalls) +} + +func TestSetupOwner_CreatePATFails_AccountAlreadyGoneStillRollsBackUser(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(nil, status.NewAccountNotFoundError("acc-1")) + + rolledBackFor := "" + rollbackCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + rollbackSetupFn: func(_ context.Context, userID string) error { + rollbackCalls++ + rolledBackFor = userID + return nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) { + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) { + return nil, errors.New("token failure") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), 
"admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + PATExpireInDays: intPtr(30), + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "create setup PAT") + assert.Equal(t, "owner-id", rolledBackFor) + assert.Equal(t, 1, rollbackCalls) +} + +func TestSetupOwner_CreatePATFails_AccountRollbackFailureStopsBeforeUserRollback(t *testing.T) { + t.Setenv(SetupPATEnabledEnvKey, "true") + + ctrl := gomock.NewController(t) + accountStore := nbstore.NewMockStore(ctrl) + account := &types.Account{Id: "acc-1"} + accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil) + accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(errors.New("delete failed")) + + rollbackCalls := 0 + setupManager := NewSetupService( + &setupInstanceManagerMock{ + rollbackSetupFn: func(_ context.Context, userID string) error { + rollbackCalls++ + return nil + }, + }, + &mock_server.MockAccountManager{ + GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) { + return "acc-1", nil + }, + CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) { + return nil, errors.New("token failure") + }, + GetStoreFunc: func() nbstore.Store { + return accountStore + }, + }, + ) + + result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{ + CreatePAT: true, + PATExpireInDays: intPtr(30), + }) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "create setup PAT") + assert.Contains(t, err.Error(), "failed to roll back setup resources") + assert.Equal(t, 0, rollbackCalls) +} diff --git a/management/server/user.go b/management/server/user.go index b1fb51195..43e0a9821 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/idp/dex" + 
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -395,8 +396,8 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.InvalidArgument, "token name can't be empty") } - if expiresIn < 1 || expiresIn > 365 { - return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") + if expiresIn < account.PATMinExpireDays || expiresIn > account.PATMaxExpireDays { + return nil, status.Errorf(status.InvalidArgument, "expiration has to be between %d and %d", account.PATMinExpireDays, account.PATMaxExpireDays) } allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Create) diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index b70f89499..c5fdbfbe0 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -3426,6 +3426,17 @@ components: description: Display name for the admin user (defaults to email if not provided) type: string example: Admin User + create_pat: + description: If true and the server has setup-time PAT issuance enabled (NB_SETUP_PAT_ENABLED=true), create a Personal Access Token for the new owner user and return it in the response. Ignored when the server feature is disabled. + type: boolean + example: true + pat_expire_in: + description: Expiration of the Personal Access Token in days. Applies only when create_pat is true and the server feature is enabled. Defaults to 1 day when omitted. 
+ type: integer + minimum: 1 + maximum: 365 + default: 1 + example: 30 required: - email - password @@ -3442,6 +3453,12 @@ components: description: Email address of the created user type: string example: admin@example.com + personal_access_token: + description: Plain text Personal Access Token created during setup. Present only when create_pat was requested and the NB_SETUP_PAT_ENABLED feature was enabled on the server. + type: string + format: password + readOnly: true + example: nbp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx required: - user_id - email @@ -4980,7 +4997,10 @@ paths: /api/setup: post: summary: Setup Instance - description: Creates the initial admin user for the instance. This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled). + description: | + Creates the initial admin user for the instance. This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled). + + When the management server is started with `NB_SETUP_PAT_ENABLED=true` and the request includes `create_pat: true`, the endpoint also provisions the NetBird account for the new owner user and returns the plain text Personal Access Token in `personal_access_token`. The optional `pat_expire_in` value applies only when `create_pat` is true and defaults to 1 day when omitted. If a post-user step fails, setup-created resources are rolled back when safe; if account cleanup fails, the owner user is left in place to avoid leaving an account without its admin user. tags: [ Instance ] security: [ ] requestBody: @@ -4993,6 +5013,12 @@ paths: responses: '200': description: Setup completed successfully + headers: + Cache-Control: + description: Always set to no-store because the response may contain a one-time plain text Personal Access Token. 
+ schema: + type: string + example: no-store content: application/json: schema: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index d56cb9b74..11cb8e46a 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -4297,6 +4297,9 @@ type SetupKeyRequest struct { // SetupRequest Request to set up the initial admin user type SetupRequest struct { + // CreatePat If true and the server has setup-time PAT issuance enabled (NB_SETUP_PAT_ENABLED=true), create a Personal Access Token for the new owner user and return it in the response. Ignored when the server feature is disabled. + CreatePat *bool `json:"create_pat,omitempty"` + // Email Email address for the admin user Email string `json:"email"` @@ -4305,6 +4308,9 @@ type SetupRequest struct { // Password Password for the admin user (minimum 8 characters) Password string `json:"password"` + + // PatExpireIn Expiration of the Personal Access Token in days. Applies only when create_pat is true and the server feature is enabled. Defaults to 1 day when omitted. + PatExpireIn *int `json:"pat_expire_in,omitempty"` } // SetupResponse Response after successful instance setup @@ -4312,6 +4318,9 @@ type SetupResponse struct { // Email Email address of the created user Email string `json:"email"` + // PersonalAccessToken Plain text Personal Access Token created during setup. Present only when create_pat was requested and the NB_SETUP_PAT_ENABLED feature was enabled on the server. + PersonalAccessToken *string `json:"personal_access_token,omitempty"` + // UserId The ID of the created user UserId string `json:"user_id"` } From c4b2da4c92520d006af448d90c6f533352b10769 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Thu, 30 Apr 2026 18:36:50 +0200 Subject: [PATCH 11/12] [management] Add public connection ipv4 and ipv6 posture check (#6038) This change enables admins to configure posture checks for connecting public IPs of their peers. 
It changes the behavior of the check as well and now the evaluation is if the received network is part of the configured network. --- management/server/posture/network.go | 47 ++++- management/server/posture/network_test.go | 200 ++++++++++++++++++++++ shared/management/http/api/openapi.yml | 9 +- shared/management/http/api/types.gen.go | 6 +- 4 files changed, 247 insertions(+), 15 deletions(-) diff --git a/management/server/posture/network.go b/management/server/posture/network.go index f78744143..4b4b3ccaa 100644 --- a/management/server/posture/network.go +++ b/management/server/posture/network.go @@ -17,19 +17,48 @@ type PeerNetworkRangeCheck struct { var _ Check = (*PeerNetworkRangeCheck)(nil) +// prefixContains reports whether outer fully contains inner (equal counts as contained). +// Requires the same address family, that outer is no more specific than inner (its +// netmask is shorter or equal), and that inner's network address falls inside outer. +// This is stricter than netip.Prefix.Contains(Addr) — a peer's /24 NIC will not match a +// configured /32 rule, since the rule covers a single host but the NIC describes a whole +// subnet whose host bits are unknown. +func prefixContains(outer, inner netip.Prefix) bool { + outer = outer.Masked() + inner = inner.Masked() + return outer.Bits() <= inner.Bits() && + outer.Addr().BitLen() == inner.Addr().BitLen() && // same family + outer.Contains(inner.Addr()) +} + +// Check evaluates configured ranges against the peer's local network interface prefixes +// and its public connection IP (as a /32 or /128). A configured range matches when it +// fully contains one of those prefixes, so operators can target both private subnets +// and public CIDRs (e.g. 1.0.0.0/24, 2.2.2.2/32). Including the connection IP is what +// lets a public-range posture check work — peer.Meta.NetworkAddresses only carries +// local NIC addresses. 
func (p *PeerNetworkRangeCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { - if len(peer.Meta.NetworkAddresses) == 0 { + peerPrefixes := make([]netip.Prefix, 0, len(peer.Meta.NetworkAddresses)+1) + for _, peerNetAddr := range peer.Meta.NetworkAddresses { + peerPrefixes = append(peerPrefixes, peerNetAddr.NetIP) + } + // Unmap collapses 4-in-6 forms (::ffff:a.b.c.d) so an IPv4 range matches. + if connIP := peer.Location.ConnectionIP; len(connIP) > 0 { + if addr, ok := netip.AddrFromSlice(connIP); ok { + addr = addr.Unmap() + peerPrefixes = append(peerPrefixes, netip.PrefixFrom(addr, addr.BitLen())) + } + } + + if len(peerPrefixes) == 0 { return false, fmt.Errorf("peer's does not contain peer network range addresses") } - maskedPrefixes := make([]netip.Prefix, 0, len(p.Ranges)) - for _, prefix := range p.Ranges { - maskedPrefixes = append(maskedPrefixes, prefix.Masked()) - } - - for _, peerNetAddr := range peer.Meta.NetworkAddresses { - peerMaskedPrefix := peerNetAddr.NetIP.Masked() - if slices.Contains(maskedPrefixes, peerMaskedPrefix) { + for _, peerPrefix := range peerPrefixes { + for _, rangePrefix := range p.Ranges { + if !prefixContains(rangePrefix, peerPrefix) { + continue + } switch p.Action { case CheckActionDeny: return false, nil diff --git a/management/server/posture/network_test.go b/management/server/posture/network_test.go index a841bbe08..4af394c62 100644 --- a/management/server/posture/network_test.go +++ b/management/server/posture/network_test.go @@ -2,6 +2,7 @@ package posture import ( "context" + "net" "net/netip" "testing" @@ -134,6 +135,205 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) { wantErr: true, isValid: false, }, + { + name: "Peer connection IP matches the denied /32", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{ + Meta: nbpeer.PeerSystemMeta{ + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: 
netip.MustParsePrefix("192.168.0.123/24")}, + }, + }, + Location: nbpeer.Location{ConnectionIP: net.ParseIP("109.41.115.194")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "Peer connection IP does not match the denied /32", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{ + Meta: nbpeer.PeerSystemMeta{ + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.0.123/24")}, + }, + }, + Location: nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "Peer connection IP matches the allowed /32 with no NetworkAddresses", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("109.41.115.194")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "IPv6 connection IP matches the denied /128", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::1/128"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("2001:db8::1")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "IPv6 connection IP does not match the denied /128", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::1/128"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("2001:db8::2")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "IPv4-mapped IPv6 connection IP matches IPv4 /32", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("::ffff:109.41.115.194")}, 
+ }, + wantErr: false, + isValid: false, + }, + { + name: "Connection IP falls inside an allowed /24 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("1.0.0.0/24"), + netip.MustParsePrefix("2.2.2.2/32"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.0.55")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "Connection IP falls inside an allowed /23 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("3.0.0.0/23"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("3.0.1.200")}, + }, + wantErr: false, + isValid: true, + }, + { + name: "Connection IP outside the allowed /24 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("1.0.0.0/24"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.1.5")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "Connection IP inside a denied /24 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("1.0.0.0/24"), + }, + }, + peer: nbpeer.Peer{ + Location: nbpeer.Location{ConnectionIP: net.ParseIP("1.0.0.7")}, + }, + wantErr: false, + isValid: false, + }, + { + name: "Local NIC /24 does not match a /32 rule even if host bit lines up", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.5/32"), + }, + }, + peer: nbpeer.Peer{ + Meta: nbpeer.PeerSystemMeta{ + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.0.5/24")}, + }, + }, + }, + wantErr: false, + isValid: false, + }, + { + name: "Local NIC address inside an allowed /16 range", + check: PeerNetworkRangeCheck{ + Action: CheckActionAllow, + Ranges: []netip.Prefix{ + 
netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + peer: nbpeer.Peer{ + Meta: nbpeer.PeerSystemMeta{ + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.5.7/24")}, + }, + }, + }, + wantErr: false, + isValid: true, + }, + { + name: "Empty NetworkAddresses and empty ConnectionIP still errors", + check: PeerNetworkRangeCheck{ + Action: CheckActionDeny, + Ranges: []netip.Prefix{ + netip.MustParsePrefix("109.41.115.194/32"), + }, + }, + peer: nbpeer.Peer{}, + wantErr: true, + isValid: false, + }, } for _, tt := range tests { diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index c5fdbfbe0..327e20614 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -1687,15 +1687,18 @@ components: - locations - action PeerNetworkRangeCheck: - description: Posture check for allow or deny access based on peer local network addresses + description: | + Posture check for allow or deny access based on the peer's IP addresses. A range matches when it + contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, + so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. 
type: object properties: ranges: - description: List of peer network ranges in CIDR notation + description: List of network ranges in CIDR notation, matched against the peer's local interface IPs and its public connection IP type: array items: type: string - example: [ "192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56" ] + example: [ "192.168.1.0/24", "10.0.0.0/8", "1.0.0.0/24", "2.2.2.2/32", "2001:db8:1234:1a00::/56" ] action: description: Action to take upon policy match type: string diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 11cb8e46a..dc916f81a 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1626,7 +1626,7 @@ type Checks struct { // OsVersionCheck Posture check for the version of operating system OsVersionCheck *OSVersionCheck `json:"os_version_check,omitempty"` - // PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses + // PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"` // ProcessCheck Posture Check for binaries exist and are running in the peer’s system @@ -3312,12 +3312,12 @@ type PeerMinimum struct { Name string `json:"name"` } -// PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses +// PeerNetworkRangeCheck Posture check for allow or deny access based on the peer's IP addresses. A range matches when it contains any of the peer's local network interface IPs or its public connection (NAT egress) IP, so ranges may target private subnets, public CIDRs, or single hosts via a /32 or /128. 
type PeerNetworkRangeCheck struct { // Action Action to take upon policy match Action PeerNetworkRangeCheckAction `json:"action"` - // Ranges List of peer network ranges in CIDR notation + // Ranges List of network ranges in CIDR notation, matched against the peer's local interface IPs and its public connection IP Ranges []string `json:"ranges"` } From 057d651d2e1f27c539a16c010c34e1ba88a117de Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 4 May 2026 18:28:56 +0900 Subject: [PATCH 12/12] [client, proxy] Add packet capture to debug bundle and CLI (#5891) --- client/Dockerfile | 1 + client/Dockerfile-rootless | 1 + client/cmd/capture.go | 196 ++++++ client/cmd/debug.go | 41 ++ client/cmd/root.go | 1 + client/cmd/service.go | 1 + client/cmd/service_controller.go | 2 +- client/cmd/service_installer.go | 4 + client/cmd/service_params.go | 6 + client/cmd/service_params_test.go | 1 + client/cmd/testutil_test.go | 2 +- client/embed/capture.go | 65 ++ client/embed/embed.go | 53 +- client/firewall/uspfilter/filter.go | 32 +- .../firewall/uspfilter/forwarder/endpoint.go | 17 +- .../firewall/uspfilter/forwarder/forwarder.go | 10 + client/firewall/uspfilter/forwarder/icmp.go | 4 + client/iface/device/device_filter.go | 54 +- client/iface/device/device_filter_test.go | 2 +- client/internal/debug/debug.go | 33 +- client/internal/engine.go | 65 ++ client/internal/lazyconn/manager/manager.go | 5 +- client/internal/netflow/store/memory.go | 4 +- .../internal/routeselector/routeselector.go | 13 +- client/proto/daemon.pb.go | 563 ++++++++++++---- client/proto/daemon.proto | 34 + client/proto/daemon_grpc.pb.go | 131 +++- client/server/capture.go | 365 ++++++++++ client/server/debug.go | 5 +- client/server/server.go | 10 +- client/server/server_test.go | 6 +- client/server/setconfig_test.go | 2 +- client/ui/debug.go | 187 ++--- client/wasm/cmd/main.go | 96 +++ client/wasm/internal/capture/capture.go | 176 +++++ 
management/server/account_test.go | 2 +- proxy/cmd/proxy/cmd/debug.go | 114 ++++ proxy/internal/debug/client.go | 70 ++ proxy/internal/debug/handler.go | 77 +++ util/capture/afpacket_linux.go | 199 ++++++ util/capture/afpacket_stub.go | 26 + util/capture/capture.go | 59 ++ util/capture/filter.go | 528 +++++++++++++++ util/capture/filter_test.go | 263 ++++++++ util/capture/pcap.go | 85 +++ util/capture/pcap_test.go | 68 ++ util/capture/session.go | 213 ++++++ util/capture/session_test.go | 144 ++++ util/capture/text.go | 638 ++++++++++++++++++ 49 files changed, 4421 insertions(+), 253 deletions(-) create mode 100644 client/cmd/capture.go create mode 100644 client/embed/capture.go create mode 100644 client/server/capture.go create mode 100644 client/wasm/internal/capture/capture.go create mode 100644 util/capture/afpacket_linux.go create mode 100644 util/capture/afpacket_stub.go create mode 100644 util/capture/capture.go create mode 100644 util/capture/filter.go create mode 100644 util/capture/filter_test.go create mode 100644 util/capture/pcap.go create mode 100644 util/capture/pcap_test.go create mode 100644 util/capture/session.go create mode 100644 util/capture/session_test.go create mode 100644 util/capture/text.go diff --git a/client/Dockerfile b/client/Dockerfile index 64d5ba04f..53e4555ef 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -17,6 +17,7 @@ ENV \ NETBIRD_BIN="/usr/local/bin/netbird" \ NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ + NB_ENABLE_CAPTURE="false" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="30" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/Dockerfile-rootless b/client/Dockerfile-rootless index 69d00aaf2..706bf40de 100644 --- a/client/Dockerfile-rootless +++ b/client/Dockerfile-rootless @@ -23,6 +23,7 @@ ENV \ NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \ NB_LOG_FILE="console,/var/lib/netbird/client.log" \ NB_DISABLE_DNS="true" \ + 
NB_ENABLE_CAPTURE="false" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="30" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/cmd/capture.go b/client/cmd/capture.go new file mode 100644 index 000000000..95caaa5cd --- /dev/null +++ b/client/cmd/capture.go @@ -0,0 +1,196 @@ +package cmd + +import ( + "context" + "fmt" + "io" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + + "github.com/hashicorp/go-multierror" + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util/capture" +) + +var captureCmd = &cobra.Command{ + Use: "capture", + Short: "Capture packets on the WireGuard interface", + Long: `Captures decrypted packets flowing through the WireGuard interface. + +Default output is human-readable text. Use --pcap or --output for pcap binary. +Requires --enable-capture to be set at service install or reconfigure time. 
+ +Examples: + netbird debug capture + netbird debug capture host 100.64.0.1 and port 443 + netbird debug capture tcp + netbird debug capture icmp + netbird debug capture src host 10.0.0.1 and dst port 80 + netbird debug capture -o capture.pcap + netbird debug capture --pcap | tshark -r - + netbird debug capture --pcap | tcpdump -r - -n`, + Args: cobra.ArbitraryArgs, + RunE: runCapture, +} + +func init() { + debugCmd.AddCommand(captureCmd) + + captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)") + captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length") + captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)") + captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)") + captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)") + captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout") +} + +func runCapture(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + cmd.PrintErrf(errCloseConnection, err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + + req, err := buildCaptureRequest(cmd, args) + if err != nil { + return err + } + + ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + stream, err := client.StartCapture(ctx, req) + if err != nil { + return handleCaptureError(err) + } + + // First Recv is the empty acceptance message from the server. If the + // device is unavailable (kernel WG, not connected, capture disabled), + // the server returns an error instead. + if _, err := stream.Recv(); err != nil { + return handleCaptureError(err) + } + + out, cleanup, err := captureOutput(cmd) + if err != nil { + return err + } + + if req.TextOutput { + cmd.PrintErrf("Capturing packets... 
Press Ctrl+C to stop.\n") + } else { + cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n") + } + + streamErr := streamCapture(ctx, cmd, stream, out) + cleanupErr := cleanup() + if streamErr != nil { + return streamErr + } + return cleanupErr +} + +func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) { + req := &proto.StartCaptureRequest{} + + if len(args) > 0 { + expr := strings.Join(args, " ") + if _, err := capture.ParseFilter(expr); err != nil { + return nil, fmt.Errorf("invalid filter: %w", err) + } + req.FilterExpr = expr + } + + if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 { + req.SnapLen = snap + } + if d, _ := cmd.Flags().GetDuration("duration"); d != 0 { + if d < 0 { + return nil, fmt.Errorf("duration must not be negative") + } + req.Duration = durationpb.New(d) + } + req.Verbose, _ = cmd.Flags().GetBool("verbose") + req.Ascii, _ = cmd.Flags().GetBool("ascii") + + outPath, _ := cmd.Flags().GetString("output") + forcePcap, _ := cmd.Flags().GetBool("pcap") + req.TextOutput = !forcePcap && outPath == "" + + return req, nil +} + +func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error { + for { + pkt, err := stream.Recv() + if err != nil { + if ctx.Err() != nil { + cmd.PrintErrf("\nCapture stopped.\n") + return nil //nolint:nilerr // user interrupted + } + if err == io.EOF { + cmd.PrintErrf("\nCapture finished.\n") + return nil + } + return handleCaptureError(err) + } + if _, err := out.Write(pkt.GetData()); err != nil { + return fmt.Errorf("write output: %w", err) + } + } +} + +// captureOutput returns the writer for capture data and a cleanup function +// that finalizes the file. Errors from the cleanup must be propagated. 
+func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) { + outPath, _ := cmd.Flags().GetString("output") + if outPath == "" { + return os.Stdout, func() error { return nil }, nil + } + + f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp") + if err != nil { + return nil, nil, fmt.Errorf("create output file: %w", err) + } + tmpPath := f.Name() + return f, func() error { + var merr *multierror.Error + if err := f.Close(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err)) + } + fi, statErr := os.Stat(tmpPath) + if statErr != nil || fi.Size() == 0 { + if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) { + merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr)) + } + return nberrors.FormatErrorOrNil(merr) + } + if err := os.Rename(tmpPath, outPath); err != nil { + merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err)) + return nberrors.FormatErrorOrNil(merr) + } + cmd.PrintErrf("Wrote %s\n", outPath) + return nberrors.FormatErrorOrNil(merr) + }, nil +} + +func handleCaptureError(err error) error { + if s, ok := status.FromError(err); ok { + return fmt.Errorf("%s", s.Message()) + } + return err +} diff --git a/client/cmd/debug.go b/client/cmd/debug.go index e3d3afe5f..2a8cdc887 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/debug" @@ -239,11 +240,50 @@ func runForDuration(cmd *cobra.Command, args []string) error { }() } + captureStarted := false + if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture { + captureTimeout := duration + 30*time.Second + const maxBundleCapture = 10 * time.Minute + if captureTimeout > maxBundleCapture { + captureTimeout 
= maxBundleCapture + } + _, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{ + Timeout: durationpb.New(captureTimeout), + }) + if err != nil { + cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message()) + } else { + captureStarted = true + cmd.Println("Packet capture started.") + // Safety: always stop on exit, even if the normal stop below runs too. + defer func() { + if captureStarted { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + cmd.PrintErrf("Failed to stop packet capture: %v\n", err) + } + } + }() + } + } + if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { return waitErr } cmd.Println("\nDuration completed") + if captureStarted { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + cmd.PrintErrf("Failed to stop packet capture: %v\n", err) + } else { + captureStarted = false + cmd.Println("Packet capture stopped.") + } + } + if cpuProfilingStarted { if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil { cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err) @@ -416,4 +456,5 @@ func init() { forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle") forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server") forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") + forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle") } diff --git a/client/cmd/root.go b/client/cmd/root.go index 
c872fe9f6..29d4328a1 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -75,6 +75,7 @@ var ( mtu uint16 profilesDisabled bool updateSettingsDisabled bool + captureEnabled bool networksDisabled bool rootCmd = &cobra.Command{ diff --git a/client/cmd/service.go b/client/cmd/service.go index f1123ce8c..56d8a8726 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -44,6 +44,7 @@ func init() { serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd) serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles") serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings") + serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture") serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. 
To persist, use: netbird service install --disable-networks") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 0943b6184..88121c067 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error { } } - serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled) + serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled) if err := serverInstance.Start(); err != nil { log.Fatalf("failed to start daemon: %v", err) } diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 5ada6f633..2d45fa063 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -59,6 +59,10 @@ func buildServiceArguments() []string { args = append(args, "--disable-update-settings") } + if captureEnabled { + args = append(args, "--enable-capture") + } + if networksDisabled { args = append(args, "--disable-networks") } diff --git a/client/cmd/service_params.go b/client/cmd/service_params.go index 5a86aebc6..192e0ac60 100644 --- a/client/cmd/service_params.go +++ b/client/cmd/service_params.go @@ -28,6 +28,7 @@ type serviceParams struct { LogFiles []string `json:"log_files,omitempty"` DisableProfiles bool `json:"disable_profiles,omitempty"` DisableUpdateSettings bool `json:"disable_update_settings,omitempty"` + EnableCapture bool `json:"enable_capture,omitempty"` DisableNetworks bool `json:"disable_networks,omitempty"` ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"` } @@ -79,6 +80,7 @@ func currentServiceParams() *serviceParams { LogFiles: logFiles, DisableProfiles: profilesDisabled, 
DisableUpdateSettings: updateSettingsDisabled, + EnableCapture: captureEnabled, DisableNetworks: networksDisabled, } @@ -144,6 +146,10 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) { updateSettingsDisabled = params.DisableUpdateSettings } + if !serviceCmd.PersistentFlags().Changed("enable-capture") { + captureEnabled = params.EnableCapture + } + if !serviceCmd.PersistentFlags().Changed("disable-networks") { networksDisabled = params.DisableNetworks } diff --git a/client/cmd/service_params_test.go b/client/cmd/service_params_test.go index 7e04e5abe..f338c12f4 100644 --- a/client/cmd/service_params_test.go +++ b/client/cmd/service_params_test.go @@ -535,6 +535,7 @@ func fieldToGlobalVar(field string) string { "LogFiles": "logFiles", "DisableProfiles": "profilesDisabled", "DisableUpdateSettings": "updateSettingsDisabled", + "EnableCapture": "captureEnabled", "DisableNetworks": "networksDisabled", "ServiceEnvVars": "serviceEnvVars", } diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index fd1007bb4..c24965e8d 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -160,7 +160,7 @@ func startClientDaemon( s := grpc.NewServer() server := client.New(ctx, - "", "", false, false, false) + "", "", false, false, false, false) if err := server.Start(); err != nil { t.Fatal(err) } diff --git a/client/embed/capture.go b/client/embed/capture.go new file mode 100644 index 000000000..30f9b496f --- /dev/null +++ b/client/embed/capture.go @@ -0,0 +1,65 @@ +package embed + +import ( + "io" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/util/capture" +) + +// CaptureOptions configures a packet capture session. +type CaptureOptions struct { + // Output receives pcap-formatted data. Nil disables pcap output. + Output io.Writer + // TextOutput receives human-readable packet summaries. Nil disables text output. + TextOutput io.Writer + // Filter is a BPF-like filter expression (e.g. 
"host 10.0.0.1 and tcp port 443"). + // Empty captures all packets. + Filter string + // Verbose adds seq/ack, TTL, window, and total length to text output. + Verbose bool + // ASCII dumps transport payload as printable ASCII after each packet line. + ASCII bool +} + +// CaptureStats reports capture session counters. +type CaptureStats struct { + Packets int64 + Bytes int64 + Dropped int64 +} + +// CaptureSession represents an active packet capture. Call Stop to end the +// capture and flush buffered packets. +type CaptureSession struct { + sess *capture.Session + engine *internal.Engine +} + +// Stop ends the capture, flushes remaining packets, and detaches from the device. +// Safe to call multiple times. +func (cs *CaptureSession) Stop() { + if cs.engine != nil { + _ = cs.engine.SetCapture(nil) + cs.engine = nil + } + if cs.sess != nil { + cs.sess.Stop() + } +} + +// Stats returns current capture counters. +func (cs *CaptureSession) Stats() CaptureStats { + s := cs.sess.Stats() + return CaptureStats{ + Packets: s.Packets, + Bytes: s.Bytes, + Dropped: s.Dropped, + } +} + +// Done returns a channel that is closed when the capture's writer goroutine +// has fully exited and all buffered packets have been flushed. 
+func (cs *CaptureSession) Done() <-chan struct{} { + return cs.sess.Done() +} diff --git a/client/embed/embed.go b/client/embed/embed.go index 88f7e541c..baa1d94d6 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/shared/management/domain" mgmProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/util/capture" ) var ( @@ -65,7 +66,7 @@ type Options struct { PrivateKey string // ManagementURL overrides the default management server URL ManagementURL string - // PreSharedKey is the pre-shared key for the WireGuard interface + // PreSharedKey is the pre-shared key for the tunnel interface PreSharedKey string // LogOutput is the output destination for logs (defaults to os.Stderr if nil) LogOutput io.Writer @@ -81,9 +82,9 @@ type Options struct { DisableClientRoutes bool // BlockInbound blocks all inbound connections from peers BlockInbound bool - // WireguardPort is the port for the WireGuard interface. Use 0 for a random port. + // WireguardPort is the port for the tunnel interface. Use 0 for a random port. WireguardPort *int - // MTU is the MTU for the WireGuard interface. + // MTU is the MTU for the tunnel interface. // Valid values are in the range 576..8192 bytes. // If non-nil, this value overrides any value stored in the config file. // If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280. @@ -469,6 +470,52 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error { return sshcommon.VerifyHostKey(storedKey, key, peerAddress) } +// StartCapture begins capturing packets on this client's tunnel device. +// Only one capture can be active at a time; starting a new one stops the previous. +// Call StopCapture (or CaptureSession.Stop) to end it. 
+func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) { + engine, err := c.getEngine() + if err != nil { + return nil, err + } + + var matcher capture.Matcher + if opts.Filter != "" { + m, err := capture.ParseFilter(opts.Filter) + if err != nil { + return nil, fmt.Errorf("parse filter: %w", err) + } + matcher = m + } + + sess, err := capture.NewSession(capture.Options{ + Output: opts.Output, + TextOutput: opts.TextOutput, + Matcher: matcher, + Verbose: opts.Verbose, + ASCII: opts.ASCII, + }) + if err != nil { + return nil, fmt.Errorf("create capture session: %w", err) + } + + if err := engine.SetCapture(sess); err != nil { + sess.Stop() + return nil, fmt.Errorf("set capture: %w", err) + } + + return &CaptureSession{sess: sess, engine: engine}, nil +} + +// StopCapture stops the active capture session if one is running. +func (c *Client) StopCapture() error { + engine, err := c.getEngine() + if err != nil { + return err + } + return engine.SetCapture(nil) +} + // getEngine safely retrieves the engine from the client with proper locking. // Returns ErrClientNotStarted if the client is not started. // Returns ErrEngineNotStarted if the engine is not available. 
diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 24b3d0167..3787e63a8 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -115,12 +115,13 @@ type Manager struct { localipmanager *localIPManager - udpTracker *conntrack.UDPTracker - icmpTracker *conntrack.ICMPTracker - tcpTracker *conntrack.TCPTracker - forwarder atomic.Pointer[forwarder.Forwarder] - logger *nblog.Logger - flowLogger nftypes.FlowLogger + udpTracker *conntrack.UDPTracker + icmpTracker *conntrack.ICMPTracker + tcpTracker *conntrack.TCPTracker + forwarder atomic.Pointer[forwarder.Forwarder] + pendingCapture atomic.Pointer[forwarder.PacketCapture] + logger *nblog.Logger + flowLogger nftypes.FlowLogger blockRule firewall.Rule @@ -351,6 +352,19 @@ func (m *Manager) determineRouting() error { return nil } +// SetPacketCapture sets or clears packet capture on the forwarder endpoint. +// This captures outbound response packets that bypass the FilteredDevice in netstack mode. +func (m *Manager) SetPacketCapture(pc forwarder.PacketCapture) { + if pc == nil { + m.pendingCapture.Store(nil) + } else { + m.pendingCapture.Store(&pc) + } + if fwder := m.forwarder.Load(); fwder != nil { + fwder.SetCapture(pc) + } +} + // initForwarder initializes the forwarder, it disables routing on errors func (m *Manager) initForwarder() error { if m.forwarder.Load() != nil { @@ -372,6 +386,11 @@ func (m *Manager) initForwarder() error { m.forwarder.Store(forwarder) + // Re-load after store: a concurrent SetPacketCapture may have seen forwarder as nil and only updated pendingCapture. 
+ if pc := m.pendingCapture.Load(); pc != nil { + forwarder.SetCapture(*pc) + } + log.Debug("forwarder initialized") return nil @@ -614,6 +633,7 @@ func (m *Manager) resetState() { } if fwder := m.forwarder.Load(); fwder != nil { + fwder.SetCapture(nil) fwder.Stop() } diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index 692a24140..96ab89af8 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -12,12 +12,19 @@ import ( nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) +// PacketCapture captures raw packets for debugging. Implementations must be +// safe for concurrent use and must not block. +type PacketCapture interface { + Offer(data []byte, outbound bool) +} + // endpoint implements stack.LinkEndpoint and handles integration with the wireguard device type endpoint struct { logger *nblog.Logger dispatcher stack.NetworkDispatcher device *wgdevice.Device mtu atomic.Uint32 + capture atomic.Pointer[PacketCapture] } func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { @@ -54,13 +61,17 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) continue } - // Send the packet through WireGuard + pktBytes := data.AsSlice() + address := netHeader.DestinationAddress() - err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) - if err != nil { + if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil { e.logger.Error1("CreateOutboundPacket: %v", err) continue } + + if pc := e.capture.Load(); pc != nil { + (*pc).Offer(pktBytes, true) + } written++ } diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index d17c3cd5c..925273f24 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -139,6 +139,16 @@ func New(iface common.IFaceMapper, 
logger *nblog.Logger, flowLogger nftypes.Flow return f, nil } +// SetCapture sets or clears the packet capture on the forwarder endpoint. +// This captures outbound packets that bypass the FilteredDevice (netstack forwarding). +func (f *Forwarder) SetCapture(pc PacketCapture) { + if pc == nil { + f.endpoint.capture.Store(nil) + return + } + f.endpoint.capture.Store(&pc) +} + func (f *Forwarder) InjectIncomingPacket(payload []byte) error { if len(payload) < header.IPv4MinimumSize { return fmt.Errorf("packet too small: %d bytes", len(payload)) diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index cb3db325d..217423901 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -270,5 +270,9 @@ func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload [] return 0 } + if pc := f.endpoint.capture.Load(); pc != nil { + (*pc).Offer(fullPacket, true) + } + return len(fullPacket) } diff --git a/client/iface/device/device_filter.go b/client/iface/device/device_filter.go index 4357d1916..fc1c65efa 100644 --- a/client/iface/device/device_filter.go +++ b/client/iface/device/device_filter.go @@ -3,6 +3,7 @@ package device import ( "net/netip" "sync" + "sync/atomic" "golang.zx2c4.com/wireguard/tun" ) @@ -28,11 +29,20 @@ type PacketFilter interface { SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) } +// PacketCapture captures raw packets for debugging. Implementations must be +// safe for concurrent use and must not block. +type PacketCapture interface { + // Offer submits a packet for capture. outbound is true for packets + // leaving the host (Read path), false for packets arriving (Write path). 
+ Offer(data []byte, outbound bool) +} + // FilteredDevice to override Read or Write of packets type FilteredDevice struct { tun.Device filter PacketFilter + capture atomic.Pointer[PacketCapture] mutex sync.RWMutex closeOnce sync.Once } @@ -63,20 +73,25 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er if n, err = d.Device.Read(bufs, sizes, offset); err != nil { return 0, err } + d.mutex.RLock() filter := d.filter d.mutex.RUnlock() - if filter == nil { - return + if filter != nil { + for i := 0; i < n; i++ { + if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) { + bufs = append(bufs[:i], bufs[i+1:]...) + sizes = append(sizes[:i], sizes[i+1:]...) + n-- + i-- + } + } } - for i := 0; i < n; i++ { - if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) { - bufs = append(bufs[:i], bufs[i+1:]...) - sizes = append(sizes[:i], sizes[i+1:]...) - n-- - i-- + if pc := d.capture.Load(); pc != nil { + for i := 0; i < n; i++ { + (*pc).Offer(bufs[i][offset:offset+sizes[i]], true) } } @@ -85,6 +100,13 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er // Write wraps write method with filtering feature func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { + // Capture before filtering so dropped packets are still visible in captures. 
+ if pc := d.capture.Load(); pc != nil { + for _, buf := range bufs { + (*pc).Offer(buf[offset:], false) + } + } + d.mutex.RLock() filter := d.filter d.mutex.RUnlock() @@ -96,9 +118,10 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { filteredBufs := make([][]byte, 0, len(bufs)) dropped := 0 for _, buf := range bufs { - if !filter.FilterInbound(buf[offset:], len(buf)) { - filteredBufs = append(filteredBufs, buf) + if filter.FilterInbound(buf[offset:], len(buf)) { dropped++ + } else { + filteredBufs = append(filteredBufs, buf) } } @@ -113,3 +136,14 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) { d.filter = filter d.mutex.Unlock() } + +// SetCapture sets or clears the packet capture sink. Pass nil to disable. +// Uses atomic store so the hot path (Read/Write) is a single pointer load +// with no locking overhead when capture is off. +func (d *FilteredDevice) SetCapture(pc PacketCapture) { + if pc == nil { + d.capture.Store(nil) + return + } + d.capture.Store(&pc) +} diff --git a/client/iface/device/device_filter_test.go b/client/iface/device/device_filter_test.go index eef783542..8fb16ca8d 100644 --- a/client/iface/device/device_filter_test.go +++ b/client/iface/device/device_filter_test.go @@ -158,7 +158,7 @@ func TestDeviceWrapperRead(t *testing.T) { t.Errorf("unexpected error: %v", err) return } - if n != 0 { + if n != 1 { t.Errorf("expected n=1, got %d", n) return } diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index bddb9a69e..90560d028 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -61,6 +61,7 @@ allocs.prof: Allocations profiling information. threadcreate.prof: Thread creation profiling information. cpu.prof: CPU profiling information. stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation. +capture.pcap: Packet capture in pcap format. Only present when capture was running during bundle collection. 
Omitted from anonymized bundles because it contains raw decrypted packet data. Anonymization Process @@ -234,6 +235,7 @@ type BundleGenerator struct { logPath string tempDir string cpuProfile []byte + capturePath string refreshStatus func() // Optional callback to refresh status before bundle generation clientMetrics MetricsExporter @@ -257,7 +259,8 @@ type GeneratorDependencies struct { LogPath string TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used. CPUProfile []byte - RefreshStatus func() // Optional callback to refresh status before bundle generation + CapturePath string + RefreshStatus func() ClientMetrics MetricsExporter } @@ -277,6 +280,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen logPath: deps.LogPath, tempDir: deps.TempDir, cpuProfile: deps.CPUProfile, + capturePath: deps.CapturePath, refreshStatus: deps.RefreshStatus, clientMetrics: deps.ClientMetrics, @@ -346,6 +350,10 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add CPU profile to debug bundle: %v", err) } + if err := g.addCaptureFile(); err != nil { + log.Errorf("failed to add capture file to debug bundle: %v", err) + } + if err := g.addStackTrace(); err != nil { log.Errorf("failed to add stack trace to debug bundle: %v", err) } @@ -669,6 +677,29 @@ func (g *BundleGenerator) addCPUProfile() error { return nil } +func (g *BundleGenerator) addCaptureFile() error { + if g.capturePath == "" { + return nil + } + + if g.anonymize { + log.Info("skipping capture file in anonymized bundle (contains raw packet data)") + return nil + } + + f, err := os.Open(g.capturePath) + if err != nil { + return fmt.Errorf("open capture file: %w", err) + } + defer f.Close() + + if err := g.addFileToZip(f, "capture.pcap"); err != nil { + return fmt.Errorf("add capture file to zip: %w", err) + } + + return nil +} + func (g *BundleGenerator) addStackTrace() error { buf := make([]byte, 5242880) // 5 MB buffer n := 
runtime.Stack(buf, true) diff --git a/client/internal/engine.go b/client/internal/engine.go index 351e4bfe9..8c9553e52 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -28,6 +28,7 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/firewalld" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" @@ -68,6 +69,7 @@ import ( signal "github.com/netbirdio/netbird/shared/signal/client" sProto "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/capture" ) // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. @@ -218,6 +220,8 @@ type Engine struct { portForwardManager *portforward.Manager srWatcher *guard.SRWatcher + afpacketCapture *capture.AFPacketCapture + // Sync response persistence (protected by syncRespMux) syncRespMux sync.RWMutex persistSyncResponse bool @@ -1703,6 +1707,11 @@ func (e *Engine) parseNATExternalIPMappings() []string { } func (e *Engine) close() { + if e.afpacketCapture != nil { + e.afpacketCapture.Stop() + e.afpacketCapture = nil + } + log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { @@ -2168,6 +2177,62 @@ func (e *Engine) Address() (netip.Addr, error) { return e.wgInterface.Address().IP, nil } +// SetCapture sets or clears packet capture on the WireGuard device. +// On userspace WireGuard, it taps the FilteredDevice directly. +// On kernel WireGuard (Linux), it falls back to AF_PACKET raw socket capture. +// Pass nil to disable capture. 
+func (e *Engine) SetCapture(pc device.PacketCapture) error { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + intf := e.wgInterface + if intf == nil { + return errors.New("wireguard interface not initialized") + } + + if e.afpacketCapture != nil { + e.afpacketCapture.Stop() + e.afpacketCapture = nil + } + + dev := intf.GetDevice() + if dev != nil { + dev.SetCapture(pc) + e.setForwarderCapture(pc) + return nil + } + + // Kernel mode: no FilteredDevice. Use AF_PACKET on Linux. + if pc == nil { + return nil + } + sess, ok := pc.(*capture.Session) + if !ok { + return errors.New("filtered device not available and AF_PACKET requires *capture.Session") + } + + afc := capture.NewAFPacketCapture(intf.Name(), sess) + if err := afc.Start(); err != nil { + return fmt.Errorf("start AF_PACKET capture on %s: %w", intf.Name(), err) + } + e.afpacketCapture = afc + return nil +} + +// setForwarderCapture propagates capture to the USP filter's forwarder endpoint. +// This captures outbound response packets that bypass the FilteredDevice in netstack mode. 
+func (e *Engine) setForwarderCapture(pc device.PacketCapture) { + if e.firewall == nil { + return + } + type forwarderCapturer interface { + SetPacketCapture(pc forwarder.PacketCapture) + } + if fc, ok := e.firewall.(forwarderCapturer); ok { + fc.SetPacketCapture(pc) + } +} + func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) { if e.firewall == nil { log.Warn("firewall is disabled, not updating forwarding rules") diff --git a/client/internal/lazyconn/manager/manager.go b/client/internal/lazyconn/manager/manager.go index b6b3c6091..fc47bda39 100644 --- a/client/internal/lazyconn/manager/manager.go +++ b/client/internal/lazyconn/manager/manager.go @@ -6,7 +6,6 @@ import ( "time" log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn/activity" @@ -91,8 +90,8 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) { m.routesMu.Lock() defer m.routesMu.Unlock() - maps.Clear(m.peerToHAGroups) - maps.Clear(m.haGroupToPeers) + clear(m.peerToHAGroups) + clear(m.haGroupToPeers) for haUniqueID, routes := range haMap { var peers []string diff --git a/client/internal/netflow/store/memory.go b/client/internal/netflow/store/memory.go index b695a0a12..a44505e96 100644 --- a/client/internal/netflow/store/memory.go +++ b/client/internal/netflow/store/memory.go @@ -3,8 +3,6 @@ package store import ( "sync" - "golang.org/x/exp/maps" - "github.com/google/uuid" "github.com/netbirdio/netbird/client/internal/netflow/types" @@ -30,7 +28,7 @@ func (m *Memory) StoreEvent(event *types.Event) { func (m *Memory) Close() { m.mux.Lock() defer m.mux.Unlock() - maps.Clear(m.events) + clear(m.events) } func (m *Memory) GetEvents() []*types.Event { diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 61c8bbc79..30afc013b 100644 --- 
a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -7,7 +7,6 @@ import ( "sync" "github.com/hashicorp/go-multierror" - "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/route" @@ -44,8 +43,8 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al if rs.selectedRoutes == nil { rs.selectedRoutes = map[route.NetID]struct{}{} } - maps.Clear(rs.deselectedRoutes) - maps.Clear(rs.selectedRoutes) + clear(rs.deselectedRoutes) + clear(rs.selectedRoutes) for _, r := range allRoutes { rs.deselectedRoutes[r] = struct{}{} } @@ -78,8 +77,8 @@ func (rs *RouteSelector) SelectAllRoutes() { if rs.selectedRoutes == nil { rs.selectedRoutes = map[route.NetID]struct{}{} } - maps.Clear(rs.deselectedRoutes) - maps.Clear(rs.selectedRoutes) + clear(rs.deselectedRoutes) + clear(rs.selectedRoutes) } // DeselectRoutes removes specific routes from the selection. @@ -116,8 +115,8 @@ func (rs *RouteSelector) DeselectAllRoutes() { if rs.selectedRoutes == nil { rs.selectedRoutes = map[route.NetID]struct{}{} } - maps.Clear(rs.deselectedRoutes) - maps.Clear(rs.selectedRoutes) + clear(rs.deselectedRoutes) + clear(rs.selectedRoutes) } // IsSelected checks if a specific route is selected. 
diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 31658d5a1..11e7877f2 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -5847,6 +5847,288 @@ func (x *ExposeServiceReady) GetPortAutoAssigned() bool { return false } +type StartCaptureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + TextOutput bool `protobuf:"varint,1,opt,name=text_output,json=textOutput,proto3" json:"text_output,omitempty"` + SnapLen uint32 `protobuf:"varint,2,opt,name=snap_len,json=snapLen,proto3" json:"snap_len,omitempty"` + Duration *durationpb.Duration `protobuf:"bytes,3,opt,name=duration,proto3" json:"duration,omitempty"` + FilterExpr string `protobuf:"bytes,4,opt,name=filter_expr,json=filterExpr,proto3" json:"filter_expr,omitempty"` + Verbose bool `protobuf:"varint,5,opt,name=verbose,proto3" json:"verbose,omitempty"` + Ascii bool `protobuf:"varint,6,opt,name=ascii,proto3" json:"ascii,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartCaptureRequest) Reset() { + *x = StartCaptureRequest{} + mi := &file_daemon_proto_msgTypes[88] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartCaptureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartCaptureRequest) ProtoMessage() {} + +func (x *StartCaptureRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[88] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartCaptureRequest.ProtoReflect.Descriptor instead. 
+func (*StartCaptureRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{88} +} + +func (x *StartCaptureRequest) GetTextOutput() bool { + if x != nil { + return x.TextOutput + } + return false +} + +func (x *StartCaptureRequest) GetSnapLen() uint32 { + if x != nil { + return x.SnapLen + } + return 0 +} + +func (x *StartCaptureRequest) GetDuration() *durationpb.Duration { + if x != nil { + return x.Duration + } + return nil +} + +func (x *StartCaptureRequest) GetFilterExpr() string { + if x != nil { + return x.FilterExpr + } + return "" +} + +func (x *StartCaptureRequest) GetVerbose() bool { + if x != nil { + return x.Verbose + } + return false +} + +func (x *StartCaptureRequest) GetAscii() bool { + if x != nil { + return x.Ascii + } + return false +} + +type CapturePacket struct { + state protoimpl.MessageState `protogen:"open.v1"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CapturePacket) Reset() { + *x = CapturePacket{} + mi := &file_daemon_proto_msgTypes[89] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CapturePacket) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CapturePacket) ProtoMessage() {} + +func (x *CapturePacket) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[89] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CapturePacket.ProtoReflect.Descriptor instead. 
+func (*CapturePacket) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{89} +} + +func (x *CapturePacket) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type StartBundleCaptureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // timeout auto-stops the capture after this duration. + // Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum. + Timeout *durationpb.Duration `protobuf:"bytes,1,opt,name=timeout,proto3" json:"timeout,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartBundleCaptureRequest) Reset() { + *x = StartBundleCaptureRequest{} + mi := &file_daemon_proto_msgTypes[90] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartBundleCaptureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartBundleCaptureRequest) ProtoMessage() {} + +func (x *StartBundleCaptureRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[90] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartBundleCaptureRequest.ProtoReflect.Descriptor instead. 
+func (*StartBundleCaptureRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{90} +} + +func (x *StartBundleCaptureRequest) GetTimeout() *durationpb.Duration { + if x != nil { + return x.Timeout + } + return nil +} + +type StartBundleCaptureResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartBundleCaptureResponse) Reset() { + *x = StartBundleCaptureResponse{} + mi := &file_daemon_proto_msgTypes[91] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartBundleCaptureResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartBundleCaptureResponse) ProtoMessage() {} + +func (x *StartBundleCaptureResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[91] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartBundleCaptureResponse.ProtoReflect.Descriptor instead. 
+func (*StartBundleCaptureResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{91} +} + +type StopBundleCaptureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopBundleCaptureRequest) Reset() { + *x = StopBundleCaptureRequest{} + mi := &file_daemon_proto_msgTypes[92] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopBundleCaptureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopBundleCaptureRequest) ProtoMessage() {} + +func (x *StopBundleCaptureRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[92] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopBundleCaptureRequest.ProtoReflect.Descriptor instead. 
+func (*StopBundleCaptureRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{92} +} + +type StopBundleCaptureResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopBundleCaptureResponse) Reset() { + *x = StopBundleCaptureResponse{} + mi := &file_daemon_proto_msgTypes[93] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopBundleCaptureResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopBundleCaptureResponse) ProtoMessage() {} + +func (x *StopBundleCaptureResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[93] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopBundleCaptureResponse.ProtoReflect.Descriptor instead. 
+func (*StopBundleCaptureResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{93} +} + type PortInfo_Range struct { state protoimpl.MessageState `protogen:"open.v1"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` @@ -5857,7 +6139,7 @@ type PortInfo_Range struct { func (x *PortInfo_Range) Reset() { *x = PortInfo_Range{} - mi := &file_daemon_proto_msgTypes[89] + mi := &file_daemon_proto_msgTypes[95] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5869,7 +6151,7 @@ func (x *PortInfo_Range) String() string { func (*PortInfo_Range) ProtoMessage() {} func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[89] + mi := &file_daemon_proto_msgTypes[95] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -6410,7 +6692,23 @@ const file_daemon_proto_rawDesc = "" + "\vservice_url\x18\x02 \x01(\tR\n" + "serviceUrl\x12\x16\n" + "\x06domain\x18\x03 \x01(\tR\x06domain\x12,\n" + - "\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned*b\n" + + "\x12port_auto_assigned\x18\x04 \x01(\bR\x10portAutoAssigned\"\xd9\x01\n" + + "\x13StartCaptureRequest\x12\x1f\n" + + "\vtext_output\x18\x01 \x01(\bR\n" + + "textOutput\x12\x19\n" + + "\bsnap_len\x18\x02 \x01(\rR\asnapLen\x125\n" + + "\bduration\x18\x03 \x01(\v2\x19.google.protobuf.DurationR\bduration\x12\x1f\n" + + "\vfilter_expr\x18\x04 \x01(\tR\n" + + "filterExpr\x12\x18\n" + + "\averbose\x18\x05 \x01(\bR\averbose\x12\x14\n" + + "\x05ascii\x18\x06 \x01(\bR\x05ascii\"#\n" + + "\rCapturePacket\x12\x12\n" + + "\x04data\x18\x01 \x01(\fR\x04data\"P\n" + + "\x19StartBundleCaptureRequest\x123\n" + + "\atimeout\x18\x01 \x01(\v2\x19.google.protobuf.DurationR\atimeout\"\x1c\n" + + "\x1aStartBundleCaptureResponse\"\x1a\n" + + "\x18StopBundleCaptureRequest\"\x1b\n" + + "\x19StopBundleCaptureResponse*b\n" + "\bLogLevel\x12\v\n" + 
"\aUNKNOWN\x10\x00\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" + @@ -6428,7 +6726,7 @@ const file_daemon_proto_rawDesc = "" + "\n" + "EXPOSE_UDP\x10\x03\x12\x0e\n" + "\n" + - "EXPOSE_TLS\x10\x042\xac\x15\n" + + "EXPOSE_TLS\x10\x042\xaf\x17\n" + "\rDaemonService\x126\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + @@ -6449,7 +6747,10 @@ const file_daemon_proto_rawDesc = "" + "CleanState\x12\x19.daemon.CleanStateRequest\x1a\x1a.daemon.CleanStateResponse\"\x00\x12H\n" + "\vDeleteState\x12\x1a.daemon.DeleteStateRequest\x1a\x1b.daemon.DeleteStateResponse\"\x00\x12u\n" + "\x1aSetSyncResponsePersistence\x12).daemon.SetSyncResponsePersistenceRequest\x1a*.daemon.SetSyncResponsePersistenceResponse\"\x00\x12H\n" + - "\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12D\n" + + "\vTracePacket\x12\x1a.daemon.TracePacketRequest\x1a\x1b.daemon.TracePacketResponse\"\x00\x12F\n" + + "\fStartCapture\x12\x1b.daemon.StartCaptureRequest\x1a\x15.daemon.CapturePacket\"\x000\x01\x12]\n" + + "\x12StartBundleCapture\x12!.daemon.StartBundleCaptureRequest\x1a\".daemon.StartBundleCaptureResponse\"\x00\x12Z\n" + + "\x11StopBundleCapture\x12 .daemon.StopBundleCaptureRequest\x1a!.daemon.StopBundleCaptureResponse\"\x00\x12D\n" + "\x0fSubscribeEvents\x12\x18.daemon.SubscribeRequest\x1a\x13.daemon.SystemEvent\"\x000\x01\x12B\n" + "\tGetEvents\x12\x18.daemon.GetEventsRequest\x1a\x19.daemon.GetEventsResponse\"\x00\x12N\n" + "\rSwitchProfile\x12\x1c.daemon.SwitchProfileRequest\x1a\x1d.daemon.SwitchProfileResponse\"\x00\x12B\n" + @@ -6483,7 +6784,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 91) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 97) var file_daemon_proto_goTypes = 
[]any{ (LogLevel)(0), // 0: daemon.LogLevel (ExposeProtocol)(0), // 1: daemon.ExposeProtocol @@ -6577,125 +6878,139 @@ var file_daemon_proto_goTypes = []any{ (*ExposeServiceRequest)(nil), // 89: daemon.ExposeServiceRequest (*ExposeServiceEvent)(nil), // 90: daemon.ExposeServiceEvent (*ExposeServiceReady)(nil), // 91: daemon.ExposeServiceReady - nil, // 92: daemon.Network.ResolvedIPsEntry - (*PortInfo_Range)(nil), // 93: daemon.PortInfo.Range - nil, // 94: daemon.SystemEvent.MetadataEntry - (*durationpb.Duration)(nil), // 95: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 96: google.protobuf.Timestamp + (*StartCaptureRequest)(nil), // 92: daemon.StartCaptureRequest + (*CapturePacket)(nil), // 93: daemon.CapturePacket + (*StartBundleCaptureRequest)(nil), // 94: daemon.StartBundleCaptureRequest + (*StartBundleCaptureResponse)(nil), // 95: daemon.StartBundleCaptureResponse + (*StopBundleCaptureRequest)(nil), // 96: daemon.StopBundleCaptureRequest + (*StopBundleCaptureResponse)(nil), // 97: daemon.StopBundleCaptureResponse + nil, // 98: daemon.Network.ResolvedIPsEntry + (*PortInfo_Range)(nil), // 99: daemon.PortInfo.Range + nil, // 100: daemon.SystemEvent.MetadataEntry + (*durationpb.Duration)(nil), // 101: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 102: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 95, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 25, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 96, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 96, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 95, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration - 23, // 5: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo - 20, // 6: daemon.FullStatus.managementState:type_name -> daemon.ManagementState - 19, // 7: daemon.FullStatus.signalState:type_name -> 
daemon.SignalState - 18, // 8: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState - 17, // 9: daemon.FullStatus.peers:type_name -> daemon.PeerState - 21, // 10: daemon.FullStatus.relays:type_name -> daemon.RelayState - 22, // 11: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 55, // 12: daemon.FullStatus.events:type_name -> daemon.SystemEvent - 24, // 13: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState - 31, // 14: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 92, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry - 93, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range - 32, // 17: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo - 32, // 18: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo - 33, // 19: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule - 0, // 20: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel - 0, // 21: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 41, // 22: daemon.ListStatesResponse.states:type_name -> daemon.State - 50, // 23: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags - 52, // 24: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage - 2, // 25: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity - 3, // 26: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category - 96, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp - 94, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry - 55, // 29: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent - 95, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration - 68, // 31: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile - 1, // 32: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol - 91, // 33: 
daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady - 30, // 34: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 5, // 35: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 7, // 36: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 9, // 37: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 11, // 38: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 13, // 39: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 15, // 40: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 26, // 41: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 28, // 42: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 28, // 43: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 4, // 44: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest - 35, // 45: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 37, // 46: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 39, // 47: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 42, // 48: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 44, // 49: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 46, // 50: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 48, // 51: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest - 51, // 52: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest - 54, // 53: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest - 56, // 54: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest - 58, // 55: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest - 60, // 56: daemon.DaemonService.SetConfig:input_type -> 
daemon.SetConfigRequest - 62, // 57: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest - 64, // 58: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest - 66, // 59: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest - 69, // 60: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest - 71, // 61: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest - 73, // 62: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest - 75, // 63: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest - 77, // 64: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest - 79, // 65: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest - 81, // 66: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest - 83, // 67: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest - 85, // 68: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest - 87, // 69: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest - 89, // 70: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest - 6, // 71: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 8, // 72: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 10, // 73: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 12, // 74: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 14, // 75: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 16, // 76: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 27, // 77: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 29, // 78: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 29, // 79: 
daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 34, // 80: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse - 36, // 81: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 38, // 82: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 40, // 83: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 43, // 84: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 45, // 85: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 47, // 86: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 49, // 87: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse - 53, // 88: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse - 55, // 89: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent - 57, // 90: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse - 59, // 91: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse - 61, // 92: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse - 63, // 93: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse - 65, // 94: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse - 67, // 95: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse - 70, // 96: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse - 72, // 97: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse - 74, // 98: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse - 76, // 99: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse - 78, // 100: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse - 80, // 101: 
daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse - 82, // 102: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse - 84, // 103: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse - 86, // 104: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse - 88, // 105: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse - 90, // 106: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent - 71, // [71:107] is the sub-list for method output_type - 35, // [35:71] is the sub-list for method input_type - 35, // [35:35] is the sub-list for extension type_name - 35, // [35:35] is the sub-list for extension extendee - 0, // [0:35] is the sub-list for field type_name + 101, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 25, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus + 102, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 102, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 101, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 23, // 5: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo + 20, // 6: daemon.FullStatus.managementState:type_name -> daemon.ManagementState + 19, // 7: daemon.FullStatus.signalState:type_name -> daemon.SignalState + 18, // 8: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState + 17, // 9: daemon.FullStatus.peers:type_name -> daemon.PeerState + 21, // 10: daemon.FullStatus.relays:type_name -> daemon.RelayState + 22, // 11: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState + 55, // 12: daemon.FullStatus.events:type_name -> daemon.SystemEvent + 24, // 13: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState + 31, // 14: daemon.ListNetworksResponse.routes:type_name -> daemon.Network 
+ 98, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 99, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range + 32, // 17: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo + 32, // 18: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo + 33, // 19: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule + 0, // 20: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel + 0, // 21: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel + 41, // 22: daemon.ListStatesResponse.states:type_name -> daemon.State + 50, // 23: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags + 52, // 24: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage + 2, // 25: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity + 3, // 26: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category + 102, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp + 100, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry + 55, // 29: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent + 101, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 68, // 31: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile + 1, // 32: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol + 91, // 33: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady + 101, // 34: daemon.StartCaptureRequest.duration:type_name -> google.protobuf.Duration + 101, // 35: daemon.StartBundleCaptureRequest.timeout:type_name -> google.protobuf.Duration + 30, // 36: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 5, // 37: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 7, // 38: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 9, // 39: daemon.DaemonService.Up:input_type -> 
daemon.UpRequest + 11, // 40: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 13, // 41: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 15, // 42: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 26, // 43: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 28, // 44: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 28, // 45: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 4, // 46: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest + 35, // 47: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 37, // 48: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 39, // 49: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 42, // 50: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 44, // 51: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 46, // 52: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 48, // 53: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest + 51, // 54: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 92, // 55: daemon.DaemonService.StartCapture:input_type -> daemon.StartCaptureRequest + 94, // 56: daemon.DaemonService.StartBundleCapture:input_type -> daemon.StartBundleCaptureRequest + 96, // 57: daemon.DaemonService.StopBundleCapture:input_type -> daemon.StopBundleCaptureRequest + 54, // 58: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest + 56, // 59: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest + 58, // 60: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest + 60, // 61: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest + 62, // 62: 
daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest + 64, // 63: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest + 66, // 64: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest + 69, // 65: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest + 71, // 66: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest + 73, // 67: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest + 75, // 68: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest + 77, // 69: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest + 79, // 70: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest + 81, // 71: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest + 83, // 72: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest + 85, // 73: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest + 87, // 74: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest + 89, // 75: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest + 6, // 76: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 8, // 77: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 10, // 78: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 12, // 79: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 14, // 80: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 16, // 81: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 27, // 82: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 29, // 83: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 29, // 84: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 
34, // 85: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse + 36, // 86: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 38, // 87: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 40, // 88: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 43, // 89: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 45, // 90: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 47, // 91: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 49, // 92: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse + 53, // 93: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 93, // 94: daemon.DaemonService.StartCapture:output_type -> daemon.CapturePacket + 95, // 95: daemon.DaemonService.StartBundleCapture:output_type -> daemon.StartBundleCaptureResponse + 97, // 96: daemon.DaemonService.StopBundleCapture:output_type -> daemon.StopBundleCaptureResponse + 55, // 97: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent + 57, // 98: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse + 59, // 99: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse + 61, // 100: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse + 63, // 101: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse + 65, // 102: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse + 67, // 103: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse + 70, // 104: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse + 72, // 105: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse + 74, // 106: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse + 76, 
// 107: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse + 78, // 108: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse + 80, // 109: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse + 82, // 110: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse + 84, // 111: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse + 86, // 112: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse + 88, // 113: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse + 90, // 114: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent + 76, // [76:115] is the sub-list for method output_type + 37, // [37:76] is the sub-list for method input_type + 37, // [37:37] is the sub-list for extension type_name + 37, // [37:37] is the sub-list for extension extendee + 0, // [0:37] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -6725,7 +7040,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), NumEnums: 4, - NumMessages: 91, + NumMessages: 97, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index f4e5b8e4d..3fee9eca8 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -64,6 +64,17 @@ service DaemonService { rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {} + // StartCapture begins streaming packet capture on the WireGuard interface. + // Requires --enable-capture set at service install/reconfigure time. + rpc StartCapture(StartCaptureRequest) returns (stream CapturePacket) {} + + // StartBundleCapture begins capturing packets to a server-side temp file + // for inclusion in the next debug bundle. 
Auto-stops after the given timeout. + rpc StartBundleCapture(StartBundleCaptureRequest) returns (StartBundleCaptureResponse) {} + + // StopBundleCapture stops the running bundle capture. Idempotent. + rpc StopBundleCapture(StopBundleCaptureRequest) returns (StopBundleCaptureResponse) {} + rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {} rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {} @@ -832,3 +843,26 @@ message ExposeServiceReady { string domain = 3; bool port_auto_assigned = 4; } + +message StartCaptureRequest { + bool text_output = 1; + uint32 snap_len = 2; + google.protobuf.Duration duration = 3; + string filter_expr = 4; + bool verbose = 5; + bool ascii = 6; +} + +message CapturePacket { + bytes data = 1; +} + +message StartBundleCaptureRequest { + // timeout auto-stops the capture after this duration. + // Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum. + google.protobuf.Duration timeout = 1; +} + +message StartBundleCaptureResponse {} +message StopBundleCaptureRequest {} +message StopBundleCaptureResponse {} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 026ee2361..66a8efcc3 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -37,6 +37,9 @@ const ( DaemonService_DeleteState_FullMethodName = "/daemon.DaemonService/DeleteState" DaemonService_SetSyncResponsePersistence_FullMethodName = "/daemon.DaemonService/SetSyncResponsePersistence" DaemonService_TracePacket_FullMethodName = "/daemon.DaemonService/TracePacket" + DaemonService_StartCapture_FullMethodName = "/daemon.DaemonService/StartCapture" + DaemonService_StartBundleCapture_FullMethodName = "/daemon.DaemonService/StartBundleCapture" + DaemonService_StopBundleCapture_FullMethodName = "/daemon.DaemonService/StopBundleCapture" DaemonService_SubscribeEvents_FullMethodName = "/daemon.DaemonService/SubscribeEvents" DaemonService_GetEvents_FullMethodName = 
"/daemon.DaemonService/GetEvents" DaemonService_SwitchProfile_FullMethodName = "/daemon.DaemonService/SwitchProfile" @@ -96,6 +99,14 @@ type DaemonServiceClient interface { // SetSyncResponsePersistence enables or disables sync response persistence SetSyncResponsePersistence(ctx context.Context, in *SetSyncResponsePersistenceRequest, opts ...grpc.CallOption) (*SetSyncResponsePersistenceResponse, error) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) + // StartCapture begins streaming packet capture on the WireGuard interface. + // Requires --enable-capture set at service install/reconfigure time. + StartCapture(ctx context.Context, in *StartCaptureRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[CapturePacket], error) + // StartBundleCapture begins capturing packets to a server-side temp file + // for inclusion in the next debug bundle. Auto-stops after the given timeout. + StartBundleCapture(ctx context.Context, in *StartBundleCaptureRequest, opts ...grpc.CallOption) (*StartBundleCaptureResponse, error) + // StopBundleCapture stops the running bundle capture. Idempotent. 
+ StopBundleCapture(ctx context.Context, in *StopBundleCaptureRequest, opts ...grpc.CallOption) (*StopBundleCaptureResponse, error) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error) @@ -313,9 +324,48 @@ func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRe return out, nil } +func (c *daemonServiceClient) StartCapture(ctx context.Context, in *StartCaptureRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[CapturePacket], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], DaemonService_StartCapture_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[StartCaptureRequest, CapturePacket]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type DaemonService_StartCaptureClient = grpc.ServerStreamingClient[CapturePacket] + +func (c *daemonServiceClient) StartBundleCapture(ctx context.Context, in *StartBundleCaptureRequest, opts ...grpc.CallOption) (*StartBundleCaptureResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(StartBundleCaptureResponse) + err := c.cc.Invoke(ctx, DaemonService_StartBundleCapture_FullMethodName, in, out, cOpts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) StopBundleCapture(ctx context.Context, in *StopBundleCaptureRequest, opts ...grpc.CallOption) (*StopBundleCaptureResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(StopBundleCaptureResponse) + err := c.cc.Invoke(ctx, DaemonService_StopBundleCapture_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], DaemonService_SubscribeEvents_FullMethodName, cOpts...) + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], DaemonService_SubscribeEvents_FullMethodName, cOpts...) if err != nil { return nil, err } @@ -494,7 +544,7 @@ func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *Instal func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], DaemonService_ExposeService_FullMethodName, cOpts...) + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[2], DaemonService_ExposeService_FullMethodName, cOpts...) 
if err != nil { return nil, err } @@ -550,6 +600,14 @@ type DaemonServiceServer interface { // SetSyncResponsePersistence enables or disables sync response persistence SetSyncResponsePersistence(context.Context, *SetSyncResponsePersistenceRequest) (*SetSyncResponsePersistenceResponse, error) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) + // StartCapture begins streaming packet capture on the WireGuard interface. + // Requires --enable-capture set at service install/reconfigure time. + StartCapture(*StartCaptureRequest, grpc.ServerStreamingServer[CapturePacket]) error + // StartBundleCapture begins capturing packets to a server-side temp file + // for inclusion in the next debug bundle. Auto-stops after the given timeout. + StartBundleCapture(context.Context, *StartBundleCaptureRequest) (*StartBundleCaptureResponse, error) + // StopBundleCapture stops the running bundle capture. Idempotent. + StopBundleCapture(context.Context, *StopBundleCaptureRequest) (*StopBundleCaptureResponse, error) SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error) @@ -641,6 +699,15 @@ func (UnimplementedDaemonServiceServer) SetSyncResponsePersistence(context.Conte func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) { return nil, status.Error(codes.Unimplemented, "method TracePacket not implemented") } +func (UnimplementedDaemonServiceServer) StartCapture(*StartCaptureRequest, grpc.ServerStreamingServer[CapturePacket]) error { + return status.Error(codes.Unimplemented, "method StartCapture not implemented") +} +func (UnimplementedDaemonServiceServer) StartBundleCapture(context.Context, *StartBundleCaptureRequest) (*StartBundleCaptureResponse, error) { + return nil, status.Error(codes.Unimplemented, 
"method StartBundleCapture not implemented") +} +func (UnimplementedDaemonServiceServer) StopBundleCapture(context.Context, *StopBundleCaptureRequest) (*StopBundleCaptureResponse, error) { + return nil, status.Error(codes.Unimplemented, "method StopBundleCapture not implemented") +} func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error { return status.Error(codes.Unimplemented, "method SubscribeEvents not implemented") } @@ -1040,6 +1107,53 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_StartCapture_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(StartCaptureRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(DaemonServiceServer).StartCapture(m, &grpc.GenericServerStream[StartCaptureRequest, CapturePacket]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. 
+type DaemonService_StartCaptureServer = grpc.ServerStreamingServer[CapturePacket] + +func _DaemonService_StartBundleCapture_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StartBundleCaptureRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).StartBundleCapture(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: DaemonService_StartBundleCapture_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).StartBundleCapture(ctx, req.(*StartBundleCaptureRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_StopBundleCapture_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StopBundleCaptureRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).StopBundleCapture(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: DaemonService_StopBundleCapture_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).StopBundleCapture(ctx, req.(*StopBundleCaptureRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _DaemonService_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerStream) error { m := new(SubscribeRequest) if err := stream.RecvMsg(m); err != nil { @@ -1429,6 +1543,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "TracePacket", Handler: _DaemonService_TracePacket_Handler, }, + { + MethodName: "StartBundleCapture", + Handler: _DaemonService_StartBundleCapture_Handler, + }, + { + MethodName: "StopBundleCapture", + Handler: _DaemonService_StopBundleCapture_Handler, + 
}, { MethodName: "GetEvents", Handler: _DaemonService_GetEvents_Handler, @@ -1495,6 +1617,11 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ }, }, Streams: []grpc.StreamDesc{ + { + StreamName: "StartCapture", + Handler: _DaemonService_StartCapture_Handler, + ServerStreams: true, + }, { StreamName: "SubscribeEvents", Handler: _DaemonService_SubscribeEvents_Handler, diff --git a/client/server/capture.go b/client/server/capture.go new file mode 100644 index 000000000..308c00338 --- /dev/null +++ b/client/server/capture.go @@ -0,0 +1,365 @@ +package server + +import ( + "context" + "io" + "os" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util/capture" +) + +const maxBundleCaptureDuration = 10 * time.Minute + +// bundleCapture holds the state of an in-progress capture destined for the +// debug bundle. The lifecycle is: +// +// StartBundleCapture → capture running, writing to temp file +// StopBundleCapture → capture stopped, temp file available +// DebugBundle → temp file included in zip, then cleaned up +type bundleCapture struct { + mu sync.Mutex + sess *capture.Session + file *os.File + engine *internal.Engine + cancel context.CancelFunc + stopped bool +} + +// stop halts the capture session and closes the pcap writer. Idempotent. +func (bc *bundleCapture) stop() { + bc.mu.Lock() + defer bc.mu.Unlock() + + if bc.stopped { + return + } + bc.stopped = true + + if bc.cancel != nil { + bc.cancel() + } + if bc.sess != nil { + bc.sess.Stop() + } +} + +// path returns the temp file path, or "" if no file exists. +func (bc *bundleCapture) path() string { + if bc.file == nil { + return "" + } + return bc.file.Name() +} + +// cleanup removes the temp file. 
+func (bc *bundleCapture) cleanup() { + if bc.file == nil { + return + } + name := bc.file.Name() + if err := bc.file.Close(); err != nil { + log.Debugf("close bundle capture file: %v", err) + } + if err := os.Remove(name); err != nil && !os.IsNotExist(err) { + log.Debugf("remove bundle capture file: %v", err) + } + bc.file = nil +} + +// StartCapture streams a pcap or text packet capture over gRPC. +// Gated by the --enable-capture service flag. +func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.DaemonService_StartCaptureServer) error { + if !s.captureEnabled { + return status.Error(codes.PermissionDenied, + "packet capture is disabled; reinstall or reconfigure the service with --enable-capture") + } + + if d := req.GetDuration(); d != nil && d.AsDuration() < 0 { + return status.Error(codes.InvalidArgument, "duration must not be negative") + } + + matcher, err := parseCaptureFilter(req) + if err != nil { + return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err) + } + + pr, pw := io.Pipe() + + opts := capture.Options{ + Matcher: matcher, + SnapLen: req.GetSnapLen(), + Verbose: req.GetVerbose(), + ASCII: req.GetAscii(), + } + if req.GetTextOutput() { + opts.TextOutput = pw + } else { + opts.Output = pw + } + + sess, err := capture.NewSession(opts) + if err != nil { + pw.Close() + return status.Errorf(codes.Internal, "create capture session: %v", err) + } + + engine, err := s.claimCapture(sess) + if err != nil { + sess.Stop() + pw.Close() + return err + } + + if err := engine.SetCapture(sess); err != nil { + s.releaseCapture(sess) + sess.Stop() + pw.Close() + return status.Errorf(codes.Internal, "set capture: %v", err) + } + + // Send an empty initial message to signal that the capture was accepted. + // The client waits for this before printing the banner, so it must arrive + // before any packet data. 
+ if err := stream.Send(&proto.CapturePacket{}); err != nil { + s.clearCaptureIfOwner(sess, engine) + sess.Stop() + pw.Close() + return status.Errorf(codes.Internal, "send initial message: %v", err) + } + + ctx := stream.Context() + if d := req.GetDuration(); d != nil { + if dur := d.AsDuration(); dur > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, dur) + defer cancel() + } + } + + go func() { + <-ctx.Done() + s.clearCaptureIfOwner(sess, engine) + sess.Stop() + pw.Close() + }() + defer pr.Close() + + log.Infof("packet capture started (text=%v, expr=%q)", req.GetTextOutput(), req.GetFilterExpr()) + defer func() { + stats := sess.Stats() + log.Infof("packet capture stopped: %d packets, %d bytes, %d dropped", + stats.Packets, stats.Bytes, stats.Dropped) + }() + + return streamToGRPC(pr, stream) +} + +func streamToGRPC(r io.Reader, stream proto.DaemonService_StartCaptureServer) error { + buf := make([]byte, 32*1024) + for { + n, readErr := r.Read(buf) + if n > 0 { + if err := stream.Send(&proto.CapturePacket{Data: buf[:n]}); err != nil { + log.Debugf("capture stream send: %v", err) + return nil //nolint:nilerr // client disconnected + } + } + if readErr != nil { + return nil //nolint:nilerr // pipe closed, capture stopped normally + } + } +} + +// StartBundleCapture begins capturing packets to a server-side temp file for +// inclusion in the next debug bundle. Not gated by --enable-capture since the +// output stays on the server (same trust level as CPU profiling). +// +// A timeout auto-stops the capture as a safety net if StopBundleCapture is +// never called (e.g. CLI crash). 
+func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCaptureRequest) (*proto.StartBundleCaptureResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.stopBundleCaptureLocked() + s.cleanupBundleCapture() + + if s.activeCapture != nil { + return nil, status.Error(codes.FailedPrecondition, "another capture is already running") + } + + engine, err := s.getCaptureEngineLocked() + if err != nil { + // Not fatal: kernel mode or not connected. Log and return success + // so the debug bundle still generates without capture data. + log.Warnf("packet capture unavailable, skipping: %v", err) + return &proto.StartBundleCaptureResponse{}, nil + } + + timeout := req.GetTimeout().AsDuration() + if timeout <= 0 || timeout > maxBundleCaptureDuration { + timeout = maxBundleCaptureDuration + } + + f, err := os.CreateTemp("", "netbird.capture.*.pcap") + if err != nil { + return nil, status.Errorf(codes.Internal, "create temp file: %v", err) + } + + sess, err := capture.NewSession(capture.Options{Output: f}) + if err != nil { + f.Close() + os.Remove(f.Name()) + return nil, status.Errorf(codes.Internal, "create capture session: %v", err) + } + + if err := engine.SetCapture(sess); err != nil { + sess.Stop() + f.Close() + os.Remove(f.Name()) + log.Warnf("packet capture unavailable (no filtered device), skipping: %v", err) + return &proto.StartBundleCaptureResponse{}, nil + } + s.activeCapture = sess + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + bc := &bundleCapture{ + sess: sess, + file: f, + engine: engine, + cancel: cancel, + } + + s.bundleCapture = bc + + go func() { + <-ctx.Done() + s.mutex.Lock() + if s.bundleCapture == bc { + s.stopBundleCaptureLocked() + } else { + bc.stop() + } + s.mutex.Unlock() + log.Infof("bundle capture auto-stopped after timeout") + }() + log.Infof("bundle capture started (timeout=%s, file=%s)", timeout, f.Name()) + + return &proto.StartBundleCaptureResponse{}, nil +} + +// StopBundleCapture stops 
the running bundle capture. Idempotent. +func (s *Server) StopBundleCapture(_ context.Context, _ *proto.StopBundleCaptureRequest) (*proto.StopBundleCaptureResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.stopBundleCaptureLocked() + return &proto.StopBundleCaptureResponse{}, nil +} + +// stopBundleCaptureLocked stops the bundle capture if running. Must hold s.mutex. +func (s *Server) stopBundleCaptureLocked() { + if s.bundleCapture == nil { + return + } + bc := s.bundleCapture + if bc.engine != nil && s.activeCapture == bc.sess { + if err := bc.engine.SetCapture(nil); err != nil { + log.Debugf("clear bundle capture: %v", err) + } + s.activeCapture = nil + } + bc.stop() + + stats := bc.sess.Stats() + log.Infof("bundle capture stopped: %d packets, %d bytes, %d dropped", + stats.Packets, stats.Bytes, stats.Dropped) +} + +// bundleCapturePath returns the temp file path if a capture has been taken, +// stops any running capture, and returns "". Called from DebugBundle. +// Must hold s.mutex. +func (s *Server) bundleCapturePath() string { + if s.bundleCapture == nil { + return "" + } + + s.bundleCapture.stop() + return s.bundleCapture.path() +} + +// cleanupBundleCapture removes the temp file and clears state. Must hold s.mutex. +func (s *Server) cleanupBundleCapture() { + if s.bundleCapture == nil { + return + } + s.bundleCapture.cleanup() + s.bundleCapture = nil +} + +// claimCapture reserves the engine's capture slot for sess. Returns +// FailedPrecondition if another capture is already active. +func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.activeCapture != nil { + return nil, status.Error(codes.FailedPrecondition, "another capture is already running") + } + engine, err := s.getCaptureEngineLocked() + if err != nil { + return nil, err + } + s.activeCapture = sess + return engine, nil +} + +// releaseCapture clears the active-capture owner if it still matches sess. 
+func (s *Server) releaseCapture(sess *capture.Session) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.activeCapture == sess { + s.activeCapture = nil + } +} + +// clearCaptureIfOwner clears engine's capture slot only if sess still owns it. +func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Engine) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.activeCapture != sess { + return + } + if err := engine.SetCapture(nil); err != nil { + log.Debugf("clear capture: %v", err) + } + s.activeCapture = nil +} + +func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) { + if s.connectClient == nil { + return nil, status.Error(codes.FailedPrecondition, "client not connected") + } + engine := s.connectClient.Engine() + if engine == nil { + return nil, status.Error(codes.FailedPrecondition, "engine not initialized") + } + return engine, nil +} + +// parseCaptureFilter returns a Matcher from the request. +// Returns nil (match all) when no filter expression is set. 
+func parseCaptureFilter(req *proto.StartCaptureRequest) (capture.Matcher, error) { + expr := req.GetFilterExpr() + if expr == "" { + return nil, nil //nolint:nilnil // nil Matcher means "match all" + } + return capture.ParseFilter(expr) +} diff --git a/client/server/debug.go b/client/server/debug.go index 81708e576..33247db5f 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -43,7 +43,9 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( }() } - // Prepare refresh callback for health probes + capturePath := s.bundleCapturePath() + defer s.cleanupBundleCapture() + var refreshStatus func() if s.connectClient != nil { engine := s.connectClient.Engine() @@ -62,6 +64,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( SyncResponse: syncResponse, LogPath: s.logFile, CPUProfile: cpuProfileData, + CapturePath: capturePath, RefreshStatus: refreshStatus, ClientMetrics: clientMetrics, }, diff --git a/client/server/server.go b/client/server/server.go index e70b83bf8..648ffa8ce 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -33,6 +33,7 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/updater" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util/capture" "github.com/netbirdio/netbird/version" ) @@ -89,7 +90,11 @@ type Server struct { profileManager *profilemanager.ServiceManager profilesDisabled bool updateSettingsDisabled bool - networksDisabled bool + captureEnabled bool + bundleCapture *bundleCapture + // activeCapture is the session currently installed on the engine; guarded by s.mutex. + activeCapture *capture.Session + networksDisabled bool sleepHandler *sleephandler.SleepHandler @@ -106,7 +111,7 @@ type oauthAuthFlow struct { } // New server instance constructor. 
-func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server { +func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, captureEnabled bool, networksDisabled bool) *Server { s := &Server{ rootCtx: ctx, logFile: logFile, @@ -115,6 +120,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable profileManager: profilemanager.NewServiceManager(configFile), profilesDisabled: profilesDisabled, updateSettingsDisabled: updateSettingsDisabled, + captureEnabled: captureEnabled, networksDisabled: networksDisabled, jwtCache: newJWTCache(), } diff --git a/client/server/server_test.go b/client/server/server_test.go index 54ad47e55..641cd85fe 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -104,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "debug", "", false, false, false) + s := New(ctx, "debug", "", false, false, false, false) s.config = config @@ -165,7 +165,7 @@ func TestServer_Up(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false, false) + s := New(ctx, "console", "", false, false, false, false) err = s.Start() require.NoError(t, err) @@ -235,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false, false) + s := New(ctx, "console", "", false, false, false, false) err = s.Start() require.NoError(t, err) diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go index 7f6847c43..b90b5653d 100644 --- a/client/server/setconfig_test.go +++ b/client/server/setconfig_test.go @@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { require.NoError(t, err) ctx := context.Background() - s := 
New(ctx, "console", "", false, false, false) + s := New(ctx, "console", "", false, false, false, false) rosenpassEnabled := true rosenpassPermissive := true diff --git a/client/ui/debug.go b/client/ui/debug.go index 4ebe4d675..cf5ac1a75 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -16,6 +16,7 @@ import ( "fyne.io/fyne/v2/widget" log "github.com/sirupsen/logrus" "github.com/skratchdot/open-golang/open" + "google.golang.org/protobuf/types/known/durationpb" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/proto" @@ -38,6 +39,7 @@ type debugCollectionParams struct { upload bool uploadURL string enablePersistence bool + capture bool } // UI components for progress tracking @@ -51,25 +53,58 @@ type progressUI struct { func (s *serviceClient) showDebugUI() { w := s.app.NewWindow("NetBird Debug") w.SetOnClosed(s.cancel) - w.Resize(fyne.NewSize(600, 500)) w.SetFixedSize(true) anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil) systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil) systemInfoCheck.SetChecked(true) + captureCheck := widget.NewCheck("Include packet capture", nil) uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil) uploadCheck.SetChecked(true) - uploadURLLabel := widget.NewLabel("Debug upload URL:") + uploadURLContainer, uploadURL := s.buildUploadSection(uploadCheck) + + debugModeContainer, runForDurationCheck, durationInput, noteLabel := s.buildDurationSection() + + statusLabel := widget.NewLabel("") + statusLabel.Hide() + progressBar := widget.NewProgressBar() + progressBar.Hide() + createButton := widget.NewButton("Create Debug Bundle", nil) + + uiControls := []fyne.Disableable{ + anonymizeCheck, systemInfoCheck, captureCheck, + uploadCheck, uploadURL, runForDurationCheck, durationInput, createButton, + } + + createButton.OnTapped = s.getCreateHandler( + statusLabel, progressBar, 
uploadCheck, uploadURL, + anonymizeCheck, systemInfoCheck, captureCheck, + runForDurationCheck, durationInput, uiControls, w, + ) + + content := container.NewVBox( + widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"), + widget.NewLabel(""), + anonymizeCheck, systemInfoCheck, captureCheck, + uploadCheck, uploadURLContainer, + widget.NewLabel(""), + debugModeContainer, noteLabel, + widget.NewLabel(""), + statusLabel, progressBar, createButton, + ) + + w.SetContent(container.NewPadded(content)) + w.Show() +} + +func (s *serviceClient) buildUploadSection(uploadCheck *widget.Check) (*fyne.Container, *widget.Entry) { uploadURL := widget.NewEntry() uploadURL.SetText(uptypes.DefaultBundleURL) uploadURL.SetPlaceHolder("Enter upload URL") - uploadURLContainer := container.NewVBox( - uploadURLLabel, - uploadURL, - ) + uploadURLContainer := container.NewVBox(widget.NewLabel("Debug upload URL:"), uploadURL) uploadCheck.OnChanged = func(checked bool) { if checked { @@ -78,13 +113,14 @@ func (s *serviceClient) showDebugUI() { uploadURLContainer.Hide() } } + return uploadURLContainer, uploadURL +} - debugModeContainer := container.NewHBox() +func (s *serviceClient) buildDurationSection() (*fyne.Container, *widget.Check, *widget.Entry, *widget.Label) { runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil) runForDurationCheck.SetChecked(true) forLabel := widget.NewLabel("for") - durationInput := widget.NewEntry() durationInput.SetText("1") minutesLabel := widget.NewLabel("minute") @@ -108,63 +144,8 @@ func (s *serviceClient) showDebugUI() { } } - debugModeContainer.Add(runForDurationCheck) - debugModeContainer.Add(forLabel) - debugModeContainer.Add(durationInput) - debugModeContainer.Add(minutesLabel) - - statusLabel := widget.NewLabel("") - statusLabel.Hide() - - progressBar := widget.NewProgressBar() - progressBar.Hide() - - createButton := widget.NewButton("Create Debug Bundle", nil) - - // UI controls that should 
be disabled during debug collection - uiControls := []fyne.Disableable{ - anonymizeCheck, - systemInfoCheck, - uploadCheck, - uploadURL, - runForDurationCheck, - durationInput, - createButton, - } - - createButton.OnTapped = s.getCreateHandler( - statusLabel, - progressBar, - uploadCheck, - uploadURL, - anonymizeCheck, - systemInfoCheck, - runForDurationCheck, - durationInput, - uiControls, - w, - ) - - content := container.NewVBox( - widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"), - widget.NewLabel(""), - anonymizeCheck, - systemInfoCheck, - uploadCheck, - uploadURLContainer, - widget.NewLabel(""), - debugModeContainer, - noteLabel, - widget.NewLabel(""), - statusLabel, - progressBar, - createButton, - ) - - paddedContent := container.NewPadded(content) - w.SetContent(paddedContent) - - w.Show() + modeContainer := container.NewHBox(runForDurationCheck, forLabel, durationInput, minutesLabel) + return modeContainer, runForDurationCheck, durationInput, noteLabel } func validateMinute(s string, minutesLabel *widget.Label) error { @@ -200,6 +181,7 @@ func (s *serviceClient) getCreateHandler( uploadURL *widget.Entry, anonymizeCheck *widget.Check, systemInfoCheck *widget.Check, + captureCheck *widget.Check, runForDurationCheck *widget.Check, duration *widget.Entry, uiControls []fyne.Disableable, @@ -222,6 +204,7 @@ func (s *serviceClient) getCreateHandler( params := &debugCollectionParams{ anonymize: anonymizeCheck.Checked, systemInfo: systemInfoCheck.Checked, + capture: captureCheck.Checked, upload: uploadCheck.Checked, uploadURL: url, enablePersistence: true, @@ -253,10 +236,7 @@ func (s *serviceClient) getCreateHandler( statusLabel.SetText("Creating debug bundle...") go s.handleDebugCreation( - anonymizeCheck.Checked, - systemInfoCheck.Checked, - uploadCheck.Checked, - url, + params, statusLabel, uiControls, w, @@ -371,7 +351,7 @@ func startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time func (s *serviceClient) 
configureServiceForDebug( conn proto.DaemonServiceClient, state *debugInitialState, - enablePersistence bool, + params *debugCollectionParams, ) { if state.wasDown { if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil { @@ -397,7 +377,7 @@ func (s *serviceClient) configureServiceForDebug( time.Sleep(time.Second) } - if enablePersistence { + if params.enablePersistence { if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{ Enabled: true, }); err != nil { @@ -417,6 +397,26 @@ func (s *serviceClient) configureServiceForDebug( if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil { log.Warnf("failed to start CPU profiling: %v", err) } + + s.startBundleCaptureIfEnabled(conn, params) +} + +func (s *serviceClient) startBundleCaptureIfEnabled(conn proto.DaemonServiceClient, params *debugCollectionParams) { + if !params.capture { + return + } + + const maxCapture = 10 * time.Minute + timeout := params.duration + 30*time.Second + if timeout > maxCapture { + timeout = maxCapture + log.Warnf("packet capture clamped to %s (server maximum)", maxCapture) + } + if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{ + Timeout: durationpb.New(timeout), + }); err != nil { + log.Warnf("failed to start bundle capture: %v", err) + } } func (s *serviceClient) collectDebugData( @@ -430,7 +430,7 @@ func (s *serviceClient) collectDebugData( var wg sync.WaitGroup startProgressTracker(ctx, &wg, params.duration, progress) - s.configureServiceForDebug(conn, state, params.enablePersistence) + s.configureServiceForDebug(conn, state, params) wg.Wait() progress.progressBar.Hide() @@ -440,6 +440,14 @@ func (s *serviceClient) collectDebugData( log.Warnf("failed to stop CPU profiling: %v", err) } + if params.capture { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != 
nil { + log.Warnf("failed to stop bundle capture: %v", err) + } + } + return nil } @@ -520,18 +528,37 @@ func handleError(progress *progressUI, errMsg string) { } func (s *serviceClient) handleDebugCreation( - anonymize bool, - systemInfo bool, - upload bool, - uploadURL string, + params *debugCollectionParams, statusLabel *widget.Label, uiControls []fyne.Disableable, w fyne.Window, ) { - log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...", - anonymize, systemInfo, upload) + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + log.Errorf("Failed to get client for debug: %v", err) + statusLabel.SetText(fmt.Sprintf("Error: %v", err)) + enableUIControls(uiControls) + return + } - resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL) + if params.capture { + if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{ + Timeout: durationpb.New(30 * time.Second), + }); err != nil { + log.Warnf("failed to start bundle capture: %v", err) + } else { + defer func() { + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil { + log.Warnf("failed to stop bundle capture: %v", err) + } + }() + time.Sleep(2 * time.Second) + } + } + + resp, err := s.createDebugBundle(params.anonymize, params.systemInfo, params.uploadURL) if err != nil { log.Errorf("Failed to create debug bundle: %v", err) statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err)) @@ -543,7 +570,7 @@ func (s *serviceClient) handleDebugCreation( uploadFailureReason := resp.GetUploadFailureReason() uploadedKey := resp.GetUploadedKey() - if upload { + if params.upload { if uploadFailureReason != "" { showUploadFailedDialog(w, localPath, uploadFailureReason) } else { diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index d8e50ab6d..cb512f132 100644 --- a/client/wasm/cmd/main.go +++ 
b/client/wasm/cmd/main.go @@ -5,6 +5,7 @@ package main import ( "context" "fmt" + "sync" "syscall/js" "time" @@ -14,6 +15,7 @@ import ( netbird "github.com/netbirdio/netbird/client/embed" sshdetection "github.com/netbirdio/netbird/client/ssh/detection" nbstatus "github.com/netbirdio/netbird/client/status" + wasmcapture "github.com/netbirdio/netbird/client/wasm/internal/capture" "github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/ssh" @@ -459,6 +461,95 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func { }) } +// createStartCaptureMethod creates the programmable packet capture method. +// Returns a JS interface with onpacket callback and stop() method. +// +// Usage from JavaScript: +// +// const cap = await client.startCapture({ filter: "tcp port 443", verbose: true }) +// cap.onpacket = (line) => console.log(line) +// const stats = cap.stop() +func createStartCaptureMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + var opts js.Value + if len(args) > 0 { + opts = args[0] + } + + return createPromise(func(resolve, reject js.Value) { + iface, err := wasmcapture.Start(client, opts) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err))) + return + } + resolve.Invoke(iface) + }) + }) +} + +// captureMethods returns capture() and stopCapture() that share state for +// the console-log shortcut. capture() logs packets to the browser console +// and stopCapture() ends it, like Ctrl+C on the CLI. 
+// +// Usage from browser devtools console: +// +// await client.capture() // capture all packets +// await client.capture("tcp") // capture with filter +// await client.capture({filter: "host 10.0.0.1", verbose: true}) +// client.stopCapture() // stop and print stats +func captureMethods(client *netbird.Client) (startFn, stopFn js.Func) { + var mu sync.Mutex + var active *wasmcapture.Handle + + startFn = js.FuncOf(func(_ js.Value, args []js.Value) any { + var opts js.Value + if len(args) > 0 { + opts = args[0] + } + + return createPromise(func(resolve, reject js.Value) { + mu.Lock() + defer mu.Unlock() + + if active != nil { + active.Stop() + active = nil + } + + h, err := wasmcapture.StartConsole(client, opts) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err))) + return + } + active = h + + console := js.Global().Get("console") + console.Call("log", "[capture] started, call client.stopCapture() to stop") + resolve.Invoke(js.Undefined()) + }) + }) + + stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { + mu.Lock() + defer mu.Unlock() + + if active == nil { + js.Global().Get("console").Call("log", "[capture] no active capture") + return js.Undefined() + } + + stats := active.Stop() + active = nil + + console := js.Global().Get("console") + console.Call("log", fmt.Sprintf("[capture] stopped: %d packets, %d bytes, %d dropped", + stats.Packets, stats.Bytes, stats.Dropped)) + return js.Undefined() + }) + + return startFn, stopFn +} + // createPromise is a helper to create JavaScript promises func createPromise(handler func(resolve, reject js.Value)) js.Value { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { @@ -521,6 +612,11 @@ func createClientObject(client *netbird.Client) js.Value { obj["statusDetail"] = createStatusDetailMethod(client) obj["getSyncResponse"] = createGetSyncResponseMethod(client) obj["setLogLevel"] = createSetLogLevelMethod(client) + obj["startCapture"] = 
createStartCaptureMethod(client) + + capStart, capStop := captureMethods(client) + obj["capture"] = capStart + obj["stopCapture"] = capStop return js.ValueOf(obj) } diff --git a/client/wasm/internal/capture/capture.go b/client/wasm/internal/capture/capture.go new file mode 100644 index 000000000..53e43c45e --- /dev/null +++ b/client/wasm/internal/capture/capture.go @@ -0,0 +1,176 @@ +//go:build js + +// Package capture bridges the util/capture package to JavaScript via syscall/js. +package capture + +import ( + "strings" + "sync" + "syscall/js" + + netbird "github.com/netbirdio/netbird/client/embed" +) + +// Handle holds a running capture session so it can be stopped later. +type Handle struct { + cs *netbird.CaptureSession + stopFn js.Func + stopped bool +} + +// Stop ends the capture and returns stats. +func (h *Handle) Stop() netbird.CaptureStats { + if h.stopped { + return h.cs.Stats() + } + h.stopped = true + h.stopFn.Release() + + h.cs.Stop() + return h.cs.Stats() +} + +func statsToJS(s netbird.CaptureStats) js.Value { + obj := js.Global().Get("Object").Call("create", js.Null()) + obj.Set("packets", js.ValueOf(s.Packets)) + obj.Set("bytes", js.ValueOf(s.Bytes)) + obj.Set("dropped", js.ValueOf(s.Dropped)) + return obj +} + +// parseOpts extracts filter/verbose/ascii from a JS options value. +func parseOpts(jsOpts js.Value) (filter string, verbose, ascii bool) { + if jsOpts.IsNull() || jsOpts.IsUndefined() { + return + } + if jsOpts.Type() == js.TypeString { + filter = jsOpts.String() + return + } + if jsOpts.Type() != js.TypeObject { + return + } + if f := jsOpts.Get("filter"); !f.IsUndefined() && !f.IsNull() { + filter = f.String() + } + if v := jsOpts.Get("verbose"); !v.IsUndefined() { + verbose = v.Truthy() + } + if a := jsOpts.Get("ascii"); !a.IsUndefined() { + ascii = a.Truthy() + } + return +} + +// Start creates a capture session and returns a JS interface for streaming text +// output. 
The returned object exposes: +// +// onpacket(callback) - set callback(string) for each text line +// stop() - stop capture and return stats { packets, bytes, dropped } +// +// Options: { filter: string, verbose: bool, ascii: bool } or just a filter string. +func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) { + filter, verbose, ascii := parseOpts(jsOpts) + + cb := &jsCallbackWriter{} + + cs, err := client.StartCapture(netbird.CaptureOptions{ + TextOutput: cb, + Filter: filter, + Verbose: verbose, + ASCII: ascii, + }) + if err != nil { + return js.Undefined(), err + } + + handle := &Handle{cs: cs} + + iface := js.Global().Get("Object").Call("create", js.Null()) + handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { + return statsToJS(handle.Stop()) + }) + iface.Set("stop", handle.stopFn) + iface.Set("onpacket", js.Undefined()) + cb.setInterface(iface) + + return iface, nil +} + +// StartConsole starts a capture that logs every packet line to console.log. +// Returns a Handle so the caller can stop it later. +func StartConsole(client *netbird.Client, jsOpts js.Value) (*Handle, error) { + filter, verbose, ascii := parseOpts(jsOpts) + + cb := &jsCallbackWriter{} + + cs, err := client.StartCapture(netbird.CaptureOptions{ + TextOutput: cb, + Filter: filter, + Verbose: verbose, + ASCII: ascii, + }) + if err != nil { + return nil, err + } + + handle := &Handle{cs: cs} + handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any { + return statsToJS(handle.Stop()) + }) + + iface := js.Global().Get("Object").Call("create", js.Null()) + console := js.Global().Get("console") + iface.Set("onpacket", console.Get("log").Call("bind", console, js.ValueOf("[capture]"))) + cb.setInterface(iface) + + return handle, nil +} + +// jsCallbackWriter is an io.Writer that buffers text until a newline, then +// invokes the JS onpacket callback with each complete line. 
+type jsCallbackWriter struct { + mu sync.Mutex + iface js.Value + buf strings.Builder +} + +func (w *jsCallbackWriter) setInterface(iface js.Value) { + w.mu.Lock() + defer w.mu.Unlock() + w.iface = iface +} + +func (w *jsCallbackWriter) Write(p []byte) (int, error) { + w.mu.Lock() + w.buf.Write(p) + + var lines []string + for { + str := w.buf.String() + idx := strings.IndexByte(str, '\n') + if idx < 0 { + break + } + lines = append(lines, str[:idx]) + w.buf.Reset() + if idx+1 < len(str) { + w.buf.WriteString(str[idx+1:]) + } + } + + iface := w.iface + w.mu.Unlock() + + if iface.IsUndefined() { + return len(p), nil + } + cb := iface.Get("onpacket") + if cb.IsUndefined() || cb.IsNull() { + return len(p), nil + } + for _, line := range lines { + cb.Invoke(js.ValueOf(line)) + } + return len(p), nil +} diff --git a/management/server/account_test.go b/management/server/account_test.go index 756c42421..e259856e3 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1761,7 +1761,7 @@ func hasNilField(x interface{}) error { if f := rv.Field(i); f.IsValid() { k := f.Kind() switch k { - case reflect.Ptr: + case reflect.Pointer: if f.IsNil() { return fmt.Errorf("field %s is nil", f.String()) } diff --git a/proxy/cmd/proxy/cmd/debug.go b/proxy/cmd/proxy/cmd/debug.go index 59f7a6b65..1b1664490 100644 --- a/proxy/cmd/proxy/cmd/debug.go +++ b/proxy/cmd/proxy/cmd/debug.go @@ -2,7 +2,12 @@ package cmd import ( "fmt" + "os" + "os/signal" + "path/filepath" "strconv" + "strings" + "syscall" "github.com/spf13/cobra" @@ -99,6 +104,27 @@ var debugStopCmd = &cobra.Command{ SilenceUsage: true, } +var debugCaptureCmd = &cobra.Command{ + Use: "capture [filter expression]", + Short: "Capture packets on a client's WireGuard interface", + Long: `Captures decrypted packets flowing through a client's WireGuard interface. + +Default output is human-readable text. Use --pcap or --output for pcap binary. +Filter arguments after the account ID use BPF-like syntax. 
+ +Examples: + netbird-proxy debug capture + netbird-proxy debug capture --duration 1m host 10.0.0.1 + netbird-proxy debug capture host 10.0.0.1 and tcp port 443 + netbird-proxy debug capture not port 22 + netbird-proxy debug capture -o capture.pcap + netbird-proxy debug capture --pcap | tcpdump -r - -n + netbird-proxy debug capture --pcap | tshark -r -`, + Args: cobra.MinimumNArgs(1), + RunE: runDebugCapture, + SilenceUsage: true, +} + func init() { debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address") debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format") @@ -110,6 +136,12 @@ func init() { debugPingCmd.Flags().StringVar(&pingTimeout, "timeout", "", "Ping timeout (e.g., 10s)") + debugCaptureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = server default)") + debugCaptureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)") + debugCaptureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length (text mode)") + debugCaptureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (text mode)") + debugCaptureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout") + debugCmd.AddCommand(debugHealthCmd) debugCmd.AddCommand(debugClientsCmd) debugCmd.AddCommand(debugStatusCmd) @@ -119,6 +151,7 @@ func init() { debugCmd.AddCommand(debugLogCmd) debugCmd.AddCommand(debugStartCmd) debugCmd.AddCommand(debugStopCmd) + debugCmd.AddCommand(debugCaptureCmd) rootCmd.AddCommand(debugCmd) } @@ -171,3 +204,84 @@ func runDebugStart(cmd *cobra.Command, args []string) error { func runDebugStop(cmd *cobra.Command, args []string) error { return getDebugClient(cmd).StopClient(cmd.Context(), args[0]) } + +func runDebugCapture(cmd *cobra.Command, args []string) error { + duration, _ := cmd.Flags().GetDuration("duration") + forcePcap, _ := 
cmd.Flags().GetBool("pcap") + verbose, _ := cmd.Flags().GetBool("verbose") + ascii, _ := cmd.Flags().GetBool("ascii") + outPath, _ := cmd.Flags().GetString("output") + + // Default to text. Use pcap when --pcap is set or --output is given. + wantText := !forcePcap && outPath == "" + + var filterExpr string + if len(args) > 1 { + filterExpr = strings.Join(args[1:], " ") + } + + ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + out, cleanup, err := captureOutputWriter(cmd, outPath) + if err != nil { + return err + } + defer cleanup() + + if wantText { + cmd.PrintErrln("Capturing packets... Press Ctrl+C to stop.") + } else { + cmd.PrintErrln("Capturing packets (pcap)... Press Ctrl+C to stop.") + } + + var durationStr string + if duration > 0 { + durationStr = duration.String() + } + + err = getDebugClient(cmd).Capture(ctx, debug.CaptureOptions{ + AccountID: args[0], + Duration: durationStr, + FilterExpr: filterExpr, + Text: wantText, + Verbose: verbose, + ASCII: ascii, + Output: out, + }) + if err != nil { + return err + } + + cmd.PrintErrln("\nCapture finished.") + return nil +} + +// captureOutputWriter returns the writer and cleanup function for capture output. 
+func captureOutputWriter(cmd *cobra.Command, outPath string) (out *os.File, cleanup func(), err error) { + if outPath != "" { + f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp") + if err != nil { + return nil, nil, fmt.Errorf("create output file: %w", err) + } + tmpPath := f.Name() + return f, func() { + if err := f.Close(); err != nil { + cmd.PrintErrf("close output file: %v\n", err) + } + if fi, err := os.Stat(tmpPath); err == nil && fi.Size() > 0 { + if err := os.Rename(tmpPath, outPath); err != nil { + cmd.PrintErrf("rename output file: %v\n", err) + } else { + cmd.PrintErrf("Wrote %s\n", outPath) + } + } else { + os.Remove(tmpPath) + } + }, nil + } + + return os.Stdout, func() { + // no cleanup needed for stdout + }, nil +} diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 01b0bc8e6..e01149522 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -310,6 +310,76 @@ func (c *Client) printError(data map[string]any) { } } +// CaptureOptions configures a capture request. +type CaptureOptions struct { + AccountID string + Duration string + FilterExpr string + Text bool + Verbose bool + ASCII bool + Output io.Writer +} + +// Capture streams a packet capture from the debug endpoint. The response body +// (pcap or text) is written directly to opts.Output until the server closes the +// connection or the context is cancelled. 
+func (c *Client) Capture(ctx context.Context, opts CaptureOptions) error { + if opts.AccountID == "" { + return fmt.Errorf("account ID is required") + } + if opts.Output == nil { + return fmt.Errorf("output writer is required") + } + + params := url.Values{} + if opts.Duration != "" { + params.Set("duration", opts.Duration) + } + if opts.FilterExpr != "" { + params.Set("filter", opts.FilterExpr) + } + if opts.Text { + params.Set("format", "text") + } + if opts.Verbose { + params.Set("verbose", "true") + } + if opts.ASCII { + params.Set("ascii", "true") + } + + path := fmt.Sprintf("/debug/clients/%s/capture", url.PathEscape(opts.AccountID)) + if len(params) > 0 { + path += "?" + params.Encode() + } + + fullURL := c.baseURL + path + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + + // Use a separate client without timeout since captures stream for their full duration. + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + _, err = io.Copy(opts.Output, resp.Body) + if err != nil && ctx.Err() != nil { + return nil + } + return err +} + func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error { data, raw, err := c.fetch(ctx, path) if err != nil { diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index c507cfad9..6cd124554 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -174,6 +174,8 @@ func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, pat h.handleClientStart(w, r, accountID) case "stop": h.handleClientStop(w, r, accountID) + case "capture": + 
h.handleCapture(w, r, accountID) default: return false } @@ -632,6 +634,81 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou }) } +const maxCaptureDuration = 30 * time.Minute + +// handleCapture streams a pcap or text packet capture for the given client. +// +// Query params: +// +// duration: capture duration (0 or absent = max, capped at 30m) +// format: "text" for human-readable output (default: pcap) +// filter: BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443") +func (h *Handler) handleCapture(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { + client, ok := h.provider.GetClient(accountID) + if !ok { + http.Error(w, "client not found", http.StatusNotFound) + return + } + + duration := maxCaptureDuration + if durationStr := r.URL.Query().Get("duration"); durationStr != "" { + d, err := time.ParseDuration(durationStr) + if err != nil { + http.Error(w, "invalid duration: "+err.Error(), http.StatusBadRequest) + return + } + if d < 0 { + http.Error(w, "duration must not be negative", http.StatusBadRequest) + return + } + if d > 0 { + duration = min(d, maxCaptureDuration) + } + } + + filter := r.URL.Query().Get("filter") + wantText := r.URL.Query().Get("format") == "text" + verbose := r.URL.Query().Get("verbose") == "true" + ascii := r.URL.Query().Get("ascii") == "true" + + opts := nbembed.CaptureOptions{Filter: filter, Verbose: verbose, ASCII: ascii} + if wantText { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + opts.TextOutput = w + } else { + w.Header().Set("Content-Type", "application/vnd.tcpdump.pcap") + w.Header().Set("Content-Disposition", + fmt.Sprintf("attachment; filename=capture-%s.pcap", accountID)) + opts.Output = w + } + + cs, err := client.StartCapture(opts) + if err != nil { + http.Error(w, "start capture: "+err.Error(), http.StatusServiceUnavailable) + return + } + defer cs.Stop() + + // Flush headers after setup succeeds so errors above can still set status codes. 
+ if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + timer := time.NewTimer(duration) + defer timer.Stop() + + select { + case <-r.Context().Done(): + case <-timer.C: + } + + cs.Stop() + + stats := cs.Stats() + h.logger.Infof("capture for %s finished: %d packets, %d bytes, %d dropped", + accountID, stats.Packets, stats.Bytes, stats.Dropped) +} + func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) { if !wantJSON { http.Redirect(w, r, "/debug", http.StatusSeeOther) diff --git a/util/capture/afpacket_linux.go b/util/capture/afpacket_linux.go new file mode 100644 index 000000000..bf59e806a --- /dev/null +++ b/util/capture/afpacket_linux.go @@ -0,0 +1,199 @@ +package capture + +import ( + "encoding/binary" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +// htons converts a uint16 from host to network (big-endian) byte order. +func htons(v uint16) uint16 { + var buf [2]byte + binary.BigEndian.PutUint16(buf[:], v) + return binary.NativeEndian.Uint16(buf[:]) +} + +// AFPacketCapture reads raw packets from a network interface using an +// AF_PACKET socket. This is the kernel-mode fallback when FilteredDevice is +// not available (kernel WireGuard). Linux only. +// +// It implements device.PacketCapture so it can be set on a Session, but it +// drives its own read loop rather than being called from FilteredDevice. +// Call Start to begin and Stop to end. +type AFPacketCapture struct { + ifaceName string + sess *Session + fd int + mu sync.Mutex + stopped chan struct{} + started atomic.Bool + closed atomic.Bool +} + +// NewAFPacketCapture creates a capture bound to the given interface. +// The session receives packets via Offer. 
+func NewAFPacketCapture(ifaceName string, sess *Session) *AFPacketCapture { + return &AFPacketCapture{ + ifaceName: ifaceName, + sess: sess, + fd: -1, + stopped: make(chan struct{}), + } +} + +// Start opens the AF_PACKET socket and begins reading packets. +// Packets are fed to the session via Offer. Returns immediately; +// the read loop runs in a goroutine. +func (c *AFPacketCapture) Start() error { + if c.sess == nil { + return errors.New("nil capture session") + } + if !c.started.CompareAndSwap(false, true) { + return errors.New("capture already started") + } + if c.closed.Load() { + c.started.Store(false) + return errors.New("cannot restart stopped capture") + } + + iface, err := net.InterfaceByName(c.ifaceName) + if err != nil { + c.started.Store(false) + return fmt.Errorf("interface %s: %w", c.ifaceName, err) + } + + fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_DGRAM|unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC, int(htons(unix.ETH_P_ALL))) + if err != nil { + c.started.Store(false) + return fmt.Errorf("create AF_PACKET socket: %w", err) + } + + addr := &unix.SockaddrLinklayer{ + Protocol: htons(unix.ETH_P_ALL), + Ifindex: iface.Index, + } + if err := unix.Bind(fd, addr); err != nil { + unix.Close(fd) + c.started.Store(false) + return fmt.Errorf("bind to %s: %w", c.ifaceName, err) + } + + c.mu.Lock() + c.fd = fd + c.mu.Unlock() + + go c.readLoop(fd) + return nil +} + +// Stop closes the socket and waits for the read loop to exit. Idempotent. 
+func (c *AFPacketCapture) Stop() { + if !c.closed.CompareAndSwap(false, true) { + if c.started.Load() { + <-c.stopped + } + return + } + + c.mu.Lock() + fd := c.fd + c.fd = -1 + c.mu.Unlock() + + if fd >= 0 { + unix.Close(fd) + } + + if c.started.Load() { + <-c.stopped + } +} + +func (c *AFPacketCapture) readLoop(fd int) { + defer close(c.stopped) + + buf := make([]byte, 65536) + pollFds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}} + + for { + if c.closed.Load() { + return + } + + ok, err := c.pollOnce(pollFds) + if err != nil { + return + } + if !ok { + continue + } + + c.recvAndOffer(fd, buf) + } +} + +// pollOnce waits for data on the fd. Returns true if data is ready, false for timeout/retry. +// Returns an error to signal the loop should exit. +func (c *AFPacketCapture) pollOnce(pollFds []unix.PollFd) (bool, error) { + n, err := unix.Poll(pollFds, 200) + if err != nil { + if errors.Is(err, unix.EINTR) { + return false, nil + } + if c.closed.Load() { + return false, errors.New("closed") + } + log.Debugf("af_packet poll: %v", err) + return false, err + } + if n == 0 { + return false, nil + } + if pollFds[0].Revents&(unix.POLLERR|unix.POLLHUP|unix.POLLNVAL) != 0 { + return false, errors.New("fd error") + } + return true, nil +} + +func (c *AFPacketCapture) recvAndOffer(fd int, buf []byte) { + nr, from, err := unix.Recvfrom(fd, buf, 0) + if err != nil { + if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) { + return + } + if !c.closed.Load() { + log.Debugf("af_packet recvfrom: %v", err) + } + return + } + if nr < 1 { + return + } + + ver := buf[0] >> 4 + if ver != 4 && ver != 6 { + return + } + + // The kernel sets Pkttype on AF_PACKET sockets: + // PACKET_HOST(0) = addressed to us (inbound) + // PACKET_OUTGOING(4) = sent by us (outbound) + outbound := false + if sa, ok := from.(*unix.SockaddrLinklayer); ok { + outbound = sa.Pkttype == unix.PACKET_OUTGOING + } + c.sess.Offer(buf[:nr], outbound) +} + +// Offer satisfies device.PacketCapture but 
is unused: the AFPacketCapture +// drives its own read loop. This exists only so the type signature is +// compatible if someone tries to set it as a PacketCapture. +func (c *AFPacketCapture) Offer([]byte, bool) { + // unused: AFPacketCapture drives its own read loop +} diff --git a/util/capture/afpacket_stub.go b/util/capture/afpacket_stub.go new file mode 100644 index 000000000..bde368e88 --- /dev/null +++ b/util/capture/afpacket_stub.go @@ -0,0 +1,26 @@ +//go:build !linux + +package capture + +import "errors" + +// AFPacketCapture is not available on this platform. +type AFPacketCapture struct{} + +// NewAFPacketCapture returns nil on non-Linux platforms. +func NewAFPacketCapture(string, *Session) *AFPacketCapture { return nil } + +// Start returns an error on non-Linux platforms. +func (c *AFPacketCapture) Start() error { + return errors.New("AF_PACKET capture is only supported on Linux") +} + +// Stop is a no-op on non-Linux platforms. +func (c *AFPacketCapture) Stop() { + // no-op on non-Linux platforms +} + +// Offer is a no-op on non-Linux platforms. +func (c *AFPacketCapture) Offer([]byte, bool) { + // no-op on non-Linux platforms +} diff --git a/util/capture/capture.go b/util/capture/capture.go new file mode 100644 index 000000000..0d92a4311 --- /dev/null +++ b/util/capture/capture.go @@ -0,0 +1,59 @@ +// Package capture provides userspace packet capture in pcap format. +// +// It taps decrypted WireGuard packets flowing through the FilteredDevice and +// writes them as pcap (readable by tcpdump, tshark, Wireshark) or as +// human-readable one-line-per-packet text. +package capture + +import "io" + +// Direction indicates whether a packet is entering or leaving the host. +type Direction uint8 + +const ( + // Inbound is a packet arriving from the network (FilteredDevice.Write path). + Inbound Direction = iota + // Outbound is a packet leaving the host (FilteredDevice.Read path). + Outbound +) + +// String returns "IN" or "OUT". 
// IP protocol numbers recognised by the filter keywords and the port parser.
const (
	protoICMP   = 1
	protoTCP    = 6
	protoUDP    = 17
	protoICMPv6 = 58
)

// ipv4FragOffsetMask selects the 13-bit fragment offset from the 16-bit
// flags/offset field at bytes 6-7 of the IPv4 header.
const ipv4FragOffsetMask = 0x1fff

// Matcher tests whether a raw packet should be captured.
type Matcher interface {
	Match(data []byte) bool
}

// Filter selects packets by flat AND'd criteria. Useful for structured APIs
// (query params, proto fields). Implements Matcher.
//
// Zero-valued fields are ignored. Host and Port match either direction; the
// Src/Dst variants match one direction only.
type Filter struct {
	SrcIP   netip.Addr
	DstIP   netip.Addr
	Host    netip.Addr
	SrcPort uint16
	DstPort uint16
	Port    uint16
	Proto   uint8
}

// IsEmpty returns true if the filter has no criteria set.
func (f *Filter) IsEmpty() bool {
	return !f.SrcIP.IsValid() && !f.DstIP.IsValid() && !f.Host.IsValid() &&
		f.SrcPort == 0 && f.DstPort == 0 && f.Port == 0 && f.Proto == 0
}

// Match implements Matcher. All non-zero fields must match (AND). An empty
// filter matches everything, including data that is not a parsable IP packet;
// a non-empty filter rejects unparsable data.
func (f *Filter) Match(data []byte) bool {
	if f.IsEmpty() {
		return true
	}
	info, ok := parsePacketInfo(data)
	if !ok {
		return false
	}
	if f.Host.IsValid() && info.srcIP != f.Host && info.dstIP != f.Host {
		return false
	}
	if f.SrcIP.IsValid() && info.srcIP != f.SrcIP {
		return false
	}
	if f.DstIP.IsValid() && info.dstIP != f.DstIP {
		return false
	}
	if f.Proto != 0 && info.proto != f.Proto {
		return false
	}
	if f.Port != 0 && info.srcPort != f.Port && info.dstPort != f.Port {
		return false
	}
	if f.SrcPort != 0 && info.srcPort != f.SrcPort {
		return false
	}
	if f.DstPort != 0 && info.dstPort != f.DstPort {
		return false
	}
	return true
}

// exprNode evaluates a filter condition against pre-parsed packet info.
type exprNode func(info *packetInfo) bool

// exprMatcher wraps an expression tree. Parses the packet once, then walks the tree.
type exprMatcher struct {
	root exprNode
}

// Match implements Matcher. Data that is not a parsable IP packet never matches.
func (m *exprMatcher) Match(data []byte) bool {
	info, ok := parsePacketInfo(data)
	if !ok {
		return false
	}
	return m.root(&info)
}

// The node constructors below each build one predicate of the expression tree.

func nodeAnd(a, b exprNode) exprNode {
	return func(info *packetInfo) bool { return a(info) && b(info) }
}

func nodeOr(a, b exprNode) exprNode {
	return func(info *packetInfo) bool { return a(info) || b(info) }
}

func nodeNot(n exprNode) exprNode {
	return func(info *packetInfo) bool { return !n(info) }
}

func nodeHost(addr netip.Addr) exprNode {
	return func(info *packetInfo) bool { return info.srcIP == addr || info.dstIP == addr }
}

func nodeSrcHost(addr netip.Addr) exprNode {
	return func(info *packetInfo) bool { return info.srcIP == addr }
}

func nodeDstHost(addr netip.Addr) exprNode {
	return func(info *packetInfo) bool { return info.dstIP == addr }
}

func nodePort(port uint16) exprNode {
	return func(info *packetInfo) bool { return info.srcPort == port || info.dstPort == port }
}

func nodeSrcPort(port uint16) exprNode {
	return func(info *packetInfo) bool { return info.srcPort == port }
}

func nodeDstPort(port uint16) exprNode {
	return func(info *packetInfo) bool { return info.dstPort == port }
}

func nodeProto(proto uint8) exprNode {
	return func(info *packetInfo) bool { return info.proto == proto }
}

func nodeFamily(family uint8) exprNode {
	return func(info *packetInfo) bool { return info.family == family }
}

func nodeNet(prefix netip.Prefix) exprNode {
	return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) || prefix.Contains(info.dstIP) }
}

func nodeSrcNet(prefix netip.Prefix) exprNode {
	return func(info *packetInfo) bool { return prefix.Contains(info.srcIP) }
}

func nodeDstNet(prefix netip.Prefix) exprNode {
	return func(info *packetInfo) bool { return prefix.Contains(info.dstIP) }
}

// packetInfo holds parsed header fields for filtering and display.
// srcPort and dstPort stay 0 when the packet carries no readable TCP/UDP header.
type packetInfo struct {
	family  uint8
	srcIP   netip.Addr
	dstIP   netip.Addr
	proto   uint8
	srcPort uint16
	dstPort uint16
	hdrLen  int
}

// parsePacketInfo dispatches on the IP version nibble. It reports false for
// empty data, unknown versions and truncated headers.
func parsePacketInfo(data []byte) (packetInfo, bool) {
	if len(data) < 1 {
		return packetInfo{}, false
	}
	switch data[0] >> 4 {
	case 4:
		return parseIPv4Info(data)
	case 6:
		return parseIPv6Info(data)
	default:
		return packetInfo{}, false
	}
}

// parseIPv4Info parses the IPv4 header, honouring the IHL field. Ports are
// read for TCP and UDP only, and only from unfragmented packets or first
// fragments.
func parseIPv4Info(data []byte) (packetInfo, bool) {
	if len(data) < 20 {
		return packetInfo{}, false
	}
	ihl := int(data[0]&0x0f) * 4
	if ihl < 20 || len(data) < ihl {
		return packetInfo{}, false
	}
	info := packetInfo{
		family: 4,
		srcIP:  netip.AddrFrom4([4]byte{data[12], data[13], data[14], data[15]}),
		dstIP:  netip.AddrFrom4([4]byte{data[16], data[17], data[18], data[19]}),
		proto:  data[9],
		hdrLen: ihl,
	}

	// In a fragment with a non-zero offset the bytes after the IP header are
	// payload continuation, not a transport header. Reading "ports" from them
	// would make port filters match arbitrary data.
	fragOffset := binary.BigEndian.Uint16(data[6:8]) & ipv4FragOffsetMask
	hasPorts := info.proto == protoTCP || info.proto == protoUDP
	if hasPorts && fragOffset == 0 && len(data) >= ihl+4 {
		info.srcPort = binary.BigEndian.Uint16(data[ihl:])
		info.dstPort = binary.BigEndian.Uint16(data[ihl+2:])
	}
	return info, true
}

// parseIPv6Info parses the fixed IPv6 header. It reads the Next Header field
// directly, so packets with extension headers (hop-by-hop, routing, fragment,
// etc.) will report the extension type as the protocol rather than the final
// transport protocol. This is acceptable for a debug capture tool.
func parseIPv6Info(data []byte) (packetInfo, bool) {
	if len(data) < 40 {
		return packetInfo{}, false
	}
	var src, dst [16]byte
	copy(src[:], data[8:24])
	copy(dst[:], data[24:40])
	info := packetInfo{
		family: 6,
		srcIP:  netip.AddrFrom16(src),
		dstIP:  netip.AddrFrom16(dst),
		proto:  data[6],
		hdrLen: 40,
	}
	if (info.proto == protoTCP || info.proto == protoUDP) && len(data) >= 44 {
		info.srcPort = binary.BigEndian.Uint16(data[40:])
		info.dstPort = binary.BigEndian.Uint16(data[42:])
	}
	return info, true
}

// ParseFilter parses a BPF-like filter expression and returns a Matcher.
// Returns nil Matcher for an empty expression (match all).
//
// Grammar (mirrors common tcpdump BPF syntax):
//
//	orExpr  = andExpr ("or" andExpr)*
//	andExpr = unary ("and" unary)*
//	unary   = "not" unary | "(" orExpr ")" | term
//
//	term    = "host" IP | "src" target | "dst" target
//	        | "port" NUM | "net" PREFIX
//	        | "tcp" | "udp" | "icmp" | "icmp6"
//	        | "ip" | "ip6" | "proto" NUM
//	target  = "host" IP | "port" NUM | "net" PREFIX | IP
//
// Adjacent terms without "and" are AND'd implicitly. Keywords are matched
// case-insensitively. Parentheses need no surrounding whitespace.
//
// Examples:
//
//	host 10.0.0.1 and tcp port 443
//	not port 22
//	not(port 22)
//	(host 10.0.0.1 or host 10.0.0.2) and tcp
//	ip6 and icmp6
//	net 10.0.0.0/24
//	src host 10.0.0.1 or dst port 80
func ParseFilter(expr string) (Matcher, error) {
	tokens := tokenize(expr)
	if len(tokens) == 0 {
		return nil, nil //nolint:nilnil // nil Matcher means "match all"
	}

	p := &parser{tokens: tokens}
	node, err := p.parseOr()
	if err != nil {
		return nil, err
	}
	if p.pos < len(p.tokens) {
		return nil, fmt.Errorf("unexpected token %q at position %d", p.tokens[p.pos], p.pos)
	}
	return &exprMatcher{root: node}, nil
}

// tokenize splits expr on whitespace and makes every parenthesis its own
// token. It returns nil for an empty or all-whitespace expression.
func tokenize(expr string) []string {
	var tokens []string
	for _, field := range strings.Fields(expr) {
		tokens = append(tokens, splitParens(field)...)
	}
	return tokens
}

// splitParens splits a whitespace-free field around every parenthesis, so
// "(foo)" becomes "(", "foo", ")" and "not(port" becomes "not", "(", "port".
func splitParens(s string) []string {
	var out []string
	start := 0
	for i := 0; i < len(s); i++ {
		if s[i] != '(' && s[i] != ')' {
			continue
		}
		if start < i {
			out = append(out, s[start:i])
		}
		out = append(out, s[i:i+1])
		start = i + 1
	}
	if start < len(s) {
		out = append(out, s[start:])
	}
	return out
}

// parser is a recursive-descent parser over a token slice.
type parser struct {
	tokens []string
	pos    int
}

// peek returns the current token lower-cased, or "" at end of input.
func (p *parser) peek() string {
	if p.pos >= len(p.tokens) {
		return ""
	}
	return strings.ToLower(p.tokens[p.pos])
}

// next returns the current token lower-cased and advances; "" at end of input.
func (p *parser) next() string {
	tok := p.peek()
	if tok != "" {
		p.pos++
	}
	return tok
}

// expect consumes one token and fails unless it equals tok.
func (p *parser) expect(tok string) error {
	got := p.next()
	if got != tok {
		return fmt.Errorf("expected %q, got %q", tok, got)
	}
	return nil
}

// parseOr parses andExpr ("or" andExpr)*.
func (p *parser) parseOr() (exprNode, error) {
	left, err := p.parseAnd()
	if err != nil {
		return nil, err
	}
	for p.peek() == "or" {
		p.next()
		right, err := p.parseAnd()
		if err != nil {
			return nil, err
		}
		left = nodeOr(left, right)
	}
	return left, nil
}

// parseAnd parses a run of unary expressions joined by "and" or by adjacency.
func (p *parser) parseAnd() (exprNode, error) {
	left, err := p.parseUnary()
	if err != nil {
		return nil, err
	}
	for {
		tok := p.peek()
		if tok == "and" {
			p.next()
			right, err := p.parseUnary()
			if err != nil {
				return nil, err
			}
			left = nodeAnd(left, right)
			continue
		}
		// Implicit AND: two atoms without "and" between them.
		// Only if the next token starts an atom (not "or", ")", or EOF).
		if tok != "" && tok != "or" && tok != ")" {
			right, err := p.parseUnary()
			if err != nil {
				return nil, err
			}
			left = nodeAnd(left, right)
			continue
		}
		break
	}
	return left, nil
}

// parseUnary parses "not" unary, a parenthesised orExpr, or a single term.
func (p *parser) parseUnary() (exprNode, error) {
	switch p.peek() {
	case "not":
		p.next()
		inner, err := p.parseUnary()
		if err != nil {
			return nil, err
		}
		return nodeNot(inner), nil
	case "(":
		p.next()
		inner, err := p.parseOr()
		if err != nil {
			return nil, err
		}
		if err := p.expect(")"); err != nil {
			return nil, fmt.Errorf("unclosed parenthesis")
		}
		return inner, nil
	default:
		return p.parseAtom()
	}
}

// parseAtom parses one term of the grammar.
func (p *parser) parseAtom() (exprNode, error) {
	tok := p.next()
	if tok == "" {
		return nil, fmt.Errorf("unexpected end of expression")
	}

	switch tok {
	case "host":
		addr, err := p.parseAddr()
		if err != nil {
			return nil, fmt.Errorf("host: %w", err)
		}
		return nodeHost(addr), nil

	case "port":
		port, err := p.parsePort()
		if err != nil {
			return nil, fmt.Errorf("port: %w", err)
		}
		return nodePort(port), nil

	case "net":
		prefix, err := p.parsePrefix()
		if err != nil {
			return nil, fmt.Errorf("net: %w", err)
		}
		return nodeNet(prefix), nil

	case "src":
		return p.parseDirTarget(true)

	case "dst":
		return p.parseDirTarget(false)

	case "tcp":
		return nodeProto(protoTCP), nil
	case "udp":
		return nodeProto(protoUDP), nil
	case "icmp":
		return nodeProto(protoICMP), nil
	case "icmp6":
		return nodeProto(protoICMPv6), nil
	case "ip":
		return nodeFamily(4), nil
	case "ip6":
		return nodeFamily(6), nil

	case "proto":
		raw := p.next()
		if raw == "" {
			return nil, fmt.Errorf("proto: missing number")
		}
		n, err := strconv.Atoi(raw)
		if err != nil || n < 0 || n > 255 {
			return nil, fmt.Errorf("proto: invalid number %q", raw)
		}
		return nodeProto(uint8(n)), nil

	default:
		return nil, fmt.Errorf("unknown filter keyword %q", tok)
	}
}

// parseDirTarget parses what follows "src" (isSrc true) or "dst" (isSrc
// false): "host" IP, "port" NUM, "net" PREFIX, or a bare IP.
func (p *parser) parseDirTarget(isSrc bool) (exprNode, error) {
	tok := p.peek()
	switch tok {
	case "host":
		p.next()
		addr, err := p.parseAddr()
		if err != nil {
			return nil, err
		}
		if isSrc {
			return nodeSrcHost(addr), nil
		}
		return nodeDstHost(addr), nil

	case "port":
		p.next()
		port, err := p.parsePort()
		if err != nil {
			return nil, err
		}
		if isSrc {
			return nodeSrcPort(port), nil
		}
		return nodeDstPort(port), nil

	case "net":
		p.next()
		prefix, err := p.parsePrefix()
		if err != nil {
			return nil, err
		}
		if isSrc {
			return nodeSrcNet(prefix), nil
		}
		return nodeDstNet(prefix), nil

	default:
		// Try as bare IP: "src 10.0.0.1"
		addr, err := p.parseAddr()
		if err != nil {
			return nil, fmt.Errorf("expected host, port, net, or IP after src/dst, got %q", tok)
		}
		if isSrc {
			return nodeSrcHost(addr), nil
		}
		return nodeDstHost(addr), nil
	}
}

// parseAddr consumes one token as an IP address. IPv4-mapped IPv6 addresses
// are unmapped so they compare equal to addresses parsed from IPv4 packets.
func (p *parser) parseAddr() (netip.Addr, error) {
	raw := p.next()
	if raw == "" {
		return netip.Addr{}, fmt.Errorf("missing IP address")
	}
	addr, err := netip.ParseAddr(raw)
	if err != nil {
		return netip.Addr{}, fmt.Errorf("invalid IP %q", raw)
	}
	return addr.Unmap(), nil
}

// parsePort consumes one token as a port number in 1..65535.
func (p *parser) parsePort() (uint16, error) {
	raw := p.next()
	if raw == "" {
		return 0, fmt.Errorf("missing port number")
	}
	n, err := strconv.Atoi(raw)
	if err != nil || n < 1 || n > 65535 {
		return 0, fmt.Errorf("invalid port %q", raw)
	}
	return uint16(n), nil
}

// parsePrefix consumes one token as a CIDR prefix such as 10.0.0.0/24.
func (p *parser) parsePrefix() (netip.Prefix, error) {
	raw := p.next()
	if raw == "" {
		return netip.Prefix{}, fmt.Errorf("missing network prefix")
	}
	prefix, err := netip.ParsePrefix(raw)
	if err != nil {
		return netip.Prefix{}, fmt.Errorf("invalid prefix %q", raw)
	}
	return prefix, nil
}
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// buildIPv4Packet creates a minimal IPv4+TCP/UDP packet for filter testing. +func buildIPv4Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte { + t.Helper() + + hdrLen := 20 + pkt := make([]byte, hdrLen+20) + pkt[0] = 0x45 + pkt[9] = proto + + src := srcIP.As4() + dst := dstIP.As4() + copy(pkt[12:16], src[:]) + copy(pkt[16:20], dst[:]) + + pkt[20] = byte(srcPort >> 8) + pkt[21] = byte(srcPort) + pkt[22] = byte(dstPort >> 8) + pkt[23] = byte(dstPort) + + return pkt +} + +// buildIPv6Packet creates a minimal IPv6+TCP/UDP packet for filter testing. +func buildIPv6Packet(t *testing.T, srcIP, dstIP netip.Addr, proto uint8, srcPort, dstPort uint16) []byte { + t.Helper() + + pkt := make([]byte, 44) // 40 header + 4 ports + pkt[0] = 0x60 // version 6 + pkt[6] = proto // next header + + src := srcIP.As16() + dst := dstIP.As16() + copy(pkt[8:24], src[:]) + copy(pkt[24:40], dst[:]) + + pkt[40] = byte(srcPort >> 8) + pkt[41] = byte(srcPort) + pkt[42] = byte(dstPort >> 8) + pkt[43] = byte(dstPort) + + return pkt +} + +// ---- Filter struct tests ---- + +func TestFilter_Empty(t *testing.T) { + f := Filter{} + assert.True(t, f.IsEmpty()) + assert.True(t, f.Match(buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443))) +} + +func TestFilter_Host(t *testing.T) { + f := Filter{Host: netip.MustParseAddr("10.0.0.1")} + assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 1234, 80))) + assert.True(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.1"), protoTCP, 1234, 80))) + assert.False(t, f.Match(buildIPv4Packet(t, netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.3"), protoTCP, 1234, 80))) +} + +func TestFilter_InvalidPacket(t *testing.T) { + f := Filter{Host: 
netip.MustParseAddr("10.0.0.1")} + assert.False(t, f.Match(nil)) + assert.False(t, f.Match([]byte{})) + assert.False(t, f.Match([]byte{0x00})) +} + +func TestParsePacketInfo_IPv4(t *testing.T) { + pkt := buildIPv4Packet(t, netip.MustParseAddr("192.168.1.1"), netip.MustParseAddr("10.0.0.1"), protoTCP, 54321, 80) + info, ok := parsePacketInfo(pkt) + require.True(t, ok) + assert.Equal(t, uint8(4), info.family) + assert.Equal(t, netip.MustParseAddr("192.168.1.1"), info.srcIP) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), info.dstIP) + assert.Equal(t, uint8(protoTCP), info.proto) + assert.Equal(t, uint16(54321), info.srcPort) + assert.Equal(t, uint16(80), info.dstPort) +} + +func TestParsePacketInfo_IPv6(t *testing.T) { + pkt := buildIPv6Packet(t, netip.MustParseAddr("fd00::1"), netip.MustParseAddr("fd00::2"), protoUDP, 1234, 53) + info, ok := parsePacketInfo(pkt) + require.True(t, ok) + assert.Equal(t, uint8(6), info.family) + assert.Equal(t, netip.MustParseAddr("fd00::1"), info.srcIP) + assert.Equal(t, netip.MustParseAddr("fd00::2"), info.dstIP) + assert.Equal(t, uint8(protoUDP), info.proto) + assert.Equal(t, uint16(1234), info.srcPort) + assert.Equal(t, uint16(53), info.dstPort) +} + +// ---- ParseFilter expression tests ---- + +func matchV4(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool { + t.Helper() + return m.Match(buildIPv4Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort)) +} + +func matchV6(t *testing.T, m Matcher, srcIP, dstIP string, proto uint8, srcPort, dstPort uint16) bool { + t.Helper() + return m.Match(buildIPv6Packet(t, netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), proto, srcPort, dstPort)) +} + +func TestParseFilter_Empty(t *testing.T) { + m, err := ParseFilter("") + require.NoError(t, err) + assert.Nil(t, m, "empty expression should return nil matcher") +} + +func TestParseFilter_Atoms(t *testing.T) { + tests := []struct { + expr string + match bool + }{ 
+ {"tcp", true}, + {"udp", false}, + {"host 10.0.0.1", true}, + {"host 10.0.0.99", false}, + {"port 443", true}, + {"port 80", false}, + {"src host 10.0.0.1", true}, + {"dst host 10.0.0.2", true}, + {"dst host 10.0.0.1", false}, + {"src port 12345", true}, + {"dst port 443", true}, + {"dst port 80", false}, + {"proto 6", true}, + {"proto 17", false}, + } + + pkt := buildIPv4Packet(t, netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), protoTCP, 12345, 443) + + for _, tt := range tests { + t.Run(tt.expr, func(t *testing.T) { + m, err := ParseFilter(tt.expr) + require.NoError(t, err) + assert.Equal(t, tt.match, m.Match(pkt)) + }) + } +} + +func TestParseFilter_And(t *testing.T) { + m, err := ParseFilter("host 10.0.0.1 and tcp port 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 443)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 55555, 443), "wrong proto") + assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host") + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 55555, 80), "wrong port") +} + +func TestParseFilter_ImplicitAnd(t *testing.T) { + // "tcp port 443" = implicit AND between tcp and port 443 + m, err := ParseFilter("tcp port 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoUDP, 1, 443)) +} + +func TestParseFilter_Or(t *testing.T) { + m, err := ParseFilter("port 80 or port 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 80)) + assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080)) +} + +func TestParseFilter_Not(t *testing.T) { + m, err := ParseFilter("not port 22") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443)) + 
assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 22)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 80)) +} + +func TestParseFilter_Parens(t *testing.T) { + m, err := ParseFilter("(port 80 or port 443) and tcp") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 443)) + assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoUDP, 1, 443), "wrong proto") + assert.False(t, matchV4(t, m, "1.2.3.4", "5.6.7.8", protoTCP, 1, 8080), "wrong port") +} + +func TestParseFilter_Family(t *testing.T) { + mV4, err := ParseFilter("ip") + require.NoError(t, err) + assert.True(t, matchV4(t, mV4, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80)) + assert.False(t, matchV6(t, mV4, "fd00::1", "fd00::2", protoTCP, 1, 80)) + + mV6, err := ParseFilter("ip6") + require.NoError(t, err) + assert.False(t, matchV4(t, mV6, "10.0.0.1", "10.0.0.2", protoTCP, 1, 80)) + assert.True(t, matchV6(t, mV6, "fd00::1", "fd00::2", protoTCP, 1, 80)) +} + +func TestParseFilter_Net(t *testing.T) { + m, err := ParseFilter("net 10.0.0.0/24") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "192.168.1.1", protoTCP, 1, 80), "src in net") + assert.True(t, matchV4(t, m, "192.168.1.1", "10.0.0.200", protoTCP, 1, 80), "dst in net") + assert.False(t, matchV4(t, m, "10.0.1.1", "192.168.1.1", protoTCP, 1, 80), "neither in net") +} + +func TestParseFilter_SrcDstNet(t *testing.T) { + m, err := ParseFilter("src net 10.0.0.0/8 and dst net 192.168.0.0/16") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.1.2.3", "192.168.1.1", protoTCP, 1, 80)) + assert.False(t, matchV4(t, m, "192.168.1.1", "10.1.2.3", protoTCP, 1, 80), "reversed") +} + +func TestParseFilter_Complex(t *testing.T) { + // Real-world: capture HTTP(S) traffic to/from specific host, excluding SSH + m, err := ParseFilter("host 10.0.0.1 and (port 80 or port 443) and not port 22") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", 
protoTCP, 55555, 443)) + assert.True(t, matchV4(t, m, "10.0.0.2", "10.0.0.1", protoTCP, 55555, 80)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 22, 443), "port 22 excluded") + assert.False(t, matchV4(t, m, "10.0.0.3", "10.0.0.2", protoTCP, 55555, 443), "wrong host") +} + +func TestParseFilter_IPv6Combined(t *testing.T) { + m, err := ParseFilter("ip6 and icmp6") + require.NoError(t, err) + assert.True(t, matchV6(t, m, "fd00::1", "fd00::2", protoICMPv6, 0, 0)) + assert.False(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoICMP, 0, 0), "wrong family") + assert.False(t, matchV6(t, m, "fd00::1", "fd00::2", protoTCP, 1, 80), "wrong proto") +} + +func TestParseFilter_CaseInsensitive(t *testing.T) { + m, err := ParseFilter("HOST 10.0.0.1 AND TCP PORT 443") + require.NoError(t, err) + assert.True(t, matchV4(t, m, "10.0.0.1", "10.0.0.2", protoTCP, 1, 443)) +} + +func TestParseFilter_Errors(t *testing.T) { + bad := []string{ + "badkeyword", + "host", + "port abc", + "port 99999", + "net invalid", + "(", + "(port 80", + "not", + "src", + } + for _, expr := range bad { + t.Run(expr, func(t *testing.T) { + _, err := ParseFilter(expr) + assert.Error(t, err, "should fail for %q", expr) + }) + } +} diff --git a/util/capture/pcap.go b/util/capture/pcap.go new file mode 100644 index 000000000..0a9057045 --- /dev/null +++ b/util/capture/pcap.go @@ -0,0 +1,85 @@ +package capture + +import ( + "encoding/binary" + "io" + "time" +) + +const ( + pcapMagic = 0xa1b2c3d4 + pcapVersionMaj = 2 + pcapVersionMin = 4 + // linkTypeRaw is LINKTYPE_RAW: raw IPv4/IPv6 packets without link-layer header. + linkTypeRaw = 101 + defaultSnapLen = 65535 +) + +// PcapWriter writes packets in pcap format to an underlying writer. +// The global header is written lazily on the first WritePacket call so that +// the writer can be used with unbuffered io.Pipes without deadlocking. +// It is not safe for concurrent use; callers must serialize access. 
+type PcapWriter struct { + w io.Writer + snapLen uint32 + headerWritten bool +} + +// NewPcapWriter creates a pcap writer. The global header is deferred until the +// first WritePacket call. +func NewPcapWriter(w io.Writer, snapLen uint32) *PcapWriter { + if snapLen == 0 { + snapLen = defaultSnapLen + } + return &PcapWriter{w: w, snapLen: snapLen} +} + +// writeGlobalHeader writes the 24-byte pcap file header. +func (pw *PcapWriter) writeGlobalHeader() error { + var hdr [24]byte + binary.LittleEndian.PutUint32(hdr[0:4], pcapMagic) + binary.LittleEndian.PutUint16(hdr[4:6], pcapVersionMaj) + binary.LittleEndian.PutUint16(hdr[6:8], pcapVersionMin) + binary.LittleEndian.PutUint32(hdr[16:20], pw.snapLen) + binary.LittleEndian.PutUint32(hdr[20:24], linkTypeRaw) + + _, err := pw.w.Write(hdr[:]) + return err +} + +// WriteHeader writes the pcap global header. Safe to call multiple times. +func (pw *PcapWriter) WriteHeader() error { + if pw.headerWritten { + return nil + } + if err := pw.writeGlobalHeader(); err != nil { + return err + } + pw.headerWritten = true + return nil +} + +// WritePacket writes a single packet record, preceded by the global header +// on the first call. 
+func (pw *PcapWriter) WritePacket(ts time.Time, data []byte) error { + if err := pw.WriteHeader(); err != nil { + return err + } + + origLen := uint32(len(data)) + if origLen > pw.snapLen { + data = data[:pw.snapLen] + } + + var hdr [16]byte + binary.LittleEndian.PutUint32(hdr[0:4], uint32(ts.Unix())) + binary.LittleEndian.PutUint32(hdr[4:8], uint32(ts.Nanosecond()/1000)) + binary.LittleEndian.PutUint32(hdr[8:12], uint32(len(data))) + binary.LittleEndian.PutUint32(hdr[12:16], origLen) + + if _, err := pw.w.Write(hdr[:]); err != nil { + return err + } + _, err := pw.w.Write(data) + return err +} diff --git a/util/capture/pcap_test.go b/util/capture/pcap_test.go new file mode 100644 index 000000000..c3d21ef4a --- /dev/null +++ b/util/capture/pcap_test.go @@ -0,0 +1,68 @@ +package capture + +import ( + "bytes" + "encoding/binary" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPcapWriter_GlobalHeader(t *testing.T) { + var buf bytes.Buffer + pw := NewPcapWriter(&buf, 0) + + // Header is lazy, so write a dummy packet to trigger it. 
+ err := pw.WritePacket(time.Now(), []byte{0x45, 0, 0, 20, 0, 0, 0, 0, 64, 1, 0, 0, 10, 0, 0, 1, 10, 0, 0, 2}) + require.NoError(t, err) + + data := buf.Bytes() + require.GreaterOrEqual(t, len(data), 24, "should contain global header") + + assert.Equal(t, uint32(pcapMagic), binary.LittleEndian.Uint32(data[0:4]), "magic number") + assert.Equal(t, uint16(pcapVersionMaj), binary.LittleEndian.Uint16(data[4:6]), "version major") + assert.Equal(t, uint16(pcapVersionMin), binary.LittleEndian.Uint16(data[6:8]), "version minor") + assert.Equal(t, uint32(defaultSnapLen), binary.LittleEndian.Uint32(data[16:20]), "snap length") + assert.Equal(t, uint32(linkTypeRaw), binary.LittleEndian.Uint32(data[20:24]), "link type") +} + +func TestPcapWriter_WritePacket(t *testing.T) { + var buf bytes.Buffer + pw := NewPcapWriter(&buf, 100) + + ts := time.Date(2025, 6, 15, 12, 30, 45, 123456000, time.UTC) + payload := make([]byte, 50) + for i := range payload { + payload[i] = byte(i) + } + + err := pw.WritePacket(ts, payload) + require.NoError(t, err) + + data := buf.Bytes()[24:] // skip global header + require.Len(t, data, 16+50, "packet header + payload") + + assert.Equal(t, uint32(ts.Unix()), binary.LittleEndian.Uint32(data[0:4]), "timestamp seconds") + assert.Equal(t, uint32(123456), binary.LittleEndian.Uint32(data[4:8]), "timestamp microseconds") + assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[8:12]), "included length") + assert.Equal(t, uint32(50), binary.LittleEndian.Uint32(data[12:16]), "original length") + assert.Equal(t, payload, data[16:], "packet data") +} + +func TestPcapWriter_SnapLen(t *testing.T) { + var buf bytes.Buffer + pw := NewPcapWriter(&buf, 10) + + ts := time.Now() + payload := make([]byte, 50) + + err := pw.WritePacket(ts, payload) + require.NoError(t, err) + + data := buf.Bytes()[24:] + assert.Equal(t, uint32(10), binary.LittleEndian.Uint32(data[8:12]), "included length should be truncated") + assert.Equal(t, uint32(50), 
binary.LittleEndian.Uint32(data[12:16]), "original length preserved") + assert.Len(t, data[16:], 10, "only snap_len bytes written") +} diff --git a/util/capture/session.go b/util/capture/session.go new file mode 100644 index 000000000..09806e10c --- /dev/null +++ b/util/capture/session.go @@ -0,0 +1,213 @@ +package capture + +import ( + "fmt" + "sync" + "sync/atomic" + "time" +) + +const defaultBufSize = 256 + +type packetEntry struct { + ts time.Time + data []byte + dir Direction +} + +// Session manages an active packet capture. Packets are offered via Offer, +// buffered in a channel, and written to configured sinks by a background +// goroutine. This keeps the hot path (FilteredDevice.Read/Write) non-blocking. +// +// The caller must call Stop when done to flush remaining packets and release +// resources. +type Session struct { + pcapW *PcapWriter + textW *TextWriter + matcher Matcher + snapLen uint32 + flushFn func() + + ch chan packetEntry + done chan struct{} + stopped chan struct{} + + closeOnce sync.Once + closed atomic.Bool + packets atomic.Int64 + bytes atomic.Int64 + dropped atomic.Int64 + started time.Time +} + +// NewSession creates and starts a capture session. At least one of +// Options.Output or Options.TextOutput must be non-nil. 
+func NewSession(opts Options) (*Session, error) { + if opts.Output == nil && opts.TextOutput == nil { + return nil, fmt.Errorf("at least one output sink required") + } + + snapLen := opts.SnapLen + if snapLen == 0 { + snapLen = defaultSnapLen + } + + bufSize := opts.BufSize + if bufSize <= 0 { + bufSize = defaultBufSize + } + + s := &Session{ + matcher: opts.Matcher, + snapLen: snapLen, + ch: make(chan packetEntry, bufSize), + done: make(chan struct{}), + stopped: make(chan struct{}), + started: time.Now(), + } + + if opts.Output != nil { + s.pcapW = NewPcapWriter(opts.Output, snapLen) + } + if opts.TextOutput != nil { + s.textW = NewTextWriter(opts.TextOutput, opts.Verbose, opts.ASCII) + } + + s.flushFn = buildFlushFn(opts.Output, opts.TextOutput) + + go s.run() + return s, nil +} + +// Offer submits a packet for capture. It returns immediately and never blocks +// the caller. If the internal buffer is full the packet is dropped silently. +// +// outbound should be true for packets leaving the host (FilteredDevice.Read +// path) and false for packets arriving (FilteredDevice.Write path). +// +// Offer satisfies the device.PacketCapture interface. +func (s *Session) Offer(data []byte, outbound bool) { + if s.closed.Load() { + return + } + + if s.matcher != nil && !s.matcher.Match(data) { + return + } + + captureLen := len(data) + if s.snapLen > 0 && uint32(captureLen) > s.snapLen { + captureLen = int(s.snapLen) + } + + copied := make([]byte, captureLen) + copy(copied, data) + + dir := Inbound + if outbound { + dir = Outbound + } + + select { + case s.ch <- packetEntry{ts: time.Now(), data: copied, dir: dir}: + s.packets.Add(1) + s.bytes.Add(int64(len(data))) + default: + s.dropped.Add(1) + } +} + +// Stop signals the session to stop accepting packets, drains any buffered +// packets to the sinks, and waits for the writer goroutine to exit. +// It is safe to call multiple times. 
+func (s *Session) Stop() { + s.closeOnce.Do(func() { + s.closed.Store(true) + close(s.done) + }) + <-s.stopped +} + +// Done returns a channel that is closed when the session's writer goroutine +// has fully exited and all buffered packets have been flushed. +func (s *Session) Done() <-chan struct{} { + return s.stopped +} + +// Stats returns current capture counters. +func (s *Session) Stats() Stats { + return Stats{ + Packets: s.packets.Load(), + Bytes: s.bytes.Load(), + Dropped: s.dropped.Load(), + } +} + +func (s *Session) run() { + defer close(s.stopped) + + for { + select { + case pkt := <-s.ch: + s.write(pkt) + case <-s.done: + s.drain() + return + } + } +} + +func (s *Session) drain() { + for { + select { + case pkt := <-s.ch: + s.write(pkt) + default: + return + } + } +} + +func (s *Session) write(pkt packetEntry) { + if s.pcapW != nil { + // Best-effort: if the writer fails (broken pipe etc.), discard silently. + _ = s.pcapW.WritePacket(pkt.ts, pkt.data) + } + if s.textW != nil { + _ = s.textW.WritePacket(pkt.ts, pkt.data, pkt.dir) + } + s.flushFn() +} + +// buildFlushFn returns a function that flushes all writers that support it. +// This covers http.Flusher and similar streaming writers. 
+func buildFlushFn(writers ...any) func() { + type flusher interface { + Flush() + } + + var fns []func() + for _, w := range writers { + if w == nil { + continue + } + if f, ok := w.(flusher); ok { + fns = append(fns, f.Flush) + } + } + + switch len(fns) { + case 0: + return func() { + // no writers to flush + } + case 1: + return fns[0] + default: + return func() { + for _, fn := range fns { + fn() + } + } + } +} diff --git a/util/capture/session_test.go b/util/capture/session_test.go new file mode 100644 index 000000000..ab27686c6 --- /dev/null +++ b/util/capture/session_test.go @@ -0,0 +1,144 @@ +package capture + +import ( + "bytes" + "encoding/binary" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSession_PcapOutput(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{ + Output: &buf, + BufSize: 16, + }) + require.NoError(t, err) + + pkt := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + + sess.Offer(pkt, true) + sess.Stop() + + data := buf.Bytes() + require.Greater(t, len(data), 24, "should have global header + at least one packet") + + // Verify global header + assert.Equal(t, uint32(pcapMagic), binary.LittleEndian.Uint32(data[0:4])) + assert.Equal(t, uint32(linkTypeRaw), binary.LittleEndian.Uint32(data[20:24])) + + // Verify packet record + pktData := data[24:] + inclLen := binary.LittleEndian.Uint32(pktData[8:12]) + assert.Equal(t, uint32(len(pkt)), inclLen) + + stats := sess.Stats() + assert.Equal(t, int64(1), stats.Packets) + assert.Equal(t, int64(len(pkt)), stats.Bytes) + assert.Equal(t, int64(0), stats.Dropped) +} + +func TestSession_TextOutput(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{ + TextOutput: &buf, + BufSize: 16, + }) + require.NoError(t, err) + + pkt := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + 
protoTCP, 12345, 443) + + sess.Offer(pkt, false) + sess.Stop() + + output := buf.String() + assert.Contains(t, output, "TCP") + assert.Contains(t, output, "10.0.0.1") + assert.Contains(t, output, "10.0.0.2") + assert.Contains(t, output, "443") + assert.Contains(t, output, "[IN TCP]") +} + +func TestSession_Filter(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{ + Output: &buf, + Matcher: &Filter{Port: 443}, + }) + require.NoError(t, err) + + pktMatch := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + pktNoMatch := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 80) + + sess.Offer(pktMatch, true) + sess.Offer(pktNoMatch, true) + sess.Stop() + + stats := sess.Stats() + assert.Equal(t, int64(1), stats.Packets, "only matching packet should be captured") +} + +func TestSession_StopIdempotent(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{Output: &buf}) + require.NoError(t, err) + + sess.Stop() + sess.Stop() // should not panic or deadlock +} + +func TestSession_OfferAfterStop(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{Output: &buf}) + require.NoError(t, err) + sess.Stop() + + pkt := buildIPv4Packet(t, + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + protoTCP, 12345, 443) + sess.Offer(pkt, true) // should not panic + + assert.Equal(t, int64(0), sess.Stats().Packets) +} + +func TestSession_Done(t *testing.T) { + var buf bytes.Buffer + sess, err := NewSession(Options{Output: &buf}) + require.NoError(t, err) + + select { + case <-sess.Done(): + t.Fatal("Done should not be closed before Stop") + default: + } + + sess.Stop() + + select { + case <-sess.Done(): + case <-time.After(time.Second): + t.Fatal("Done should be closed after Stop") + } +} + +func TestSession_RequiresOutput(t *testing.T) { + _, err := NewSession(Options{}) + assert.Error(t, 
err) +} diff --git a/util/capture/text.go b/util/capture/text.go new file mode 100644 index 000000000..b44bd0cad --- /dev/null +++ b/util/capture/text.go @@ -0,0 +1,638 @@ +package capture + +import ( + "encoding/binary" + "fmt" + "io" + "net/netip" + "strings" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +// TextWriter writes human-readable one-line-per-packet summaries. +// It is not safe for concurrent use; callers must serialize access. +type TextWriter struct { + w io.Writer + verbose bool + ascii bool + flows map[dirKey]uint32 +} + +type dirKey struct { + src netip.AddrPort + dst netip.AddrPort +} + +// NewTextWriter creates a text formatter that writes to w. +func NewTextWriter(w io.Writer, verbose, ascii bool) *TextWriter { + return &TextWriter{ + w: w, + verbose: verbose, + ascii: ascii, + flows: make(map[dirKey]uint32), + } +} + +// tag formats the fixed-width "[DIR PROTO]" prefix with right-aligned protocol. +func tag(dir Direction, proto string) string { + return fmt.Sprintf("[%-3s %4s]", dir, proto) +} + +// WritePacket formats and writes a single packet line. +func (tw *TextWriter) WritePacket(ts time.Time, data []byte, dir Direction) error { + ts = ts.Local() + info, ok := parsePacketInfo(data) + if !ok { + _, err := fmt.Fprintf(tw.w, "%s [%-3s ?] ??? 
len=%d\n", + ts.Format("15:04:05.000000"), dir, len(data)) + return err + } + + timeStr := ts.Format("15:04:05.000000") + + var err error + switch info.proto { + case protoTCP: + err = tw.writeTCP(timeStr, dir, &info, data) + case protoUDP: + err = tw.writeUDP(timeStr, dir, &info, data) + case protoICMP: + err = tw.writeICMPv4(timeStr, dir, &info, data) + case protoICMPv6: + err = tw.writeICMPv6(timeStr, dir, &info, data) + default: + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err = fmt.Fprintf(tw.w, "%s %s %s > %s length %d%s\n", + timeStr, tag(dir, fmt.Sprintf("P%d", info.proto)), + info.srcIP, info.dstIP, len(data)-info.hdrLen, verbose) + } + return err +} + +func (tw *TextWriter) writeTCP(timeStr string, dir Direction, info *packetInfo, data []byte) error { + tcp := &layers.TCP{} + if err := tcp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "TCP", info, data) + } + + flags := tcpFlagsStr(tcp) + plen := len(tcp.Payload) + + // Protocol annotation + var annotation string + if plen > 0 { + annotation = annotatePayload(tcp.Payload) + } + + if !tw.verbose { + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d [%s] length %d%s\n", + timeStr, tag(dir, "TCP"), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + flags, plen, annotation) + if err != nil { + return err + } + if tw.ascii && plen > 0 { + return tw.writeASCII(tcp.Payload) + } + return nil + } + + relSeq, relAck := tw.relativeSeqAck(info, tcp.Seq, tcp.Ack) + + var seqStr string + if plen > 0 { + seqStr = fmt.Sprintf(", seq %d:%d", relSeq, relSeq+uint32(plen)) + } else { + seqStr = fmt.Sprintf(", seq %d", relSeq) + } + + var ackStr string + if tcp.ACK { + ackStr = fmt.Sprintf(", ack %d", relAck) + } + + var opts string + if s := formatTCPOptions(tcp.Options); s != "" { + opts = ", options [" + s + "]" + } + + verbose := tw.verboseIP(data, info.family) + + _, err := fmt.Fprintf(tw.w, "%s %s 
%s:%d > %s:%d [%s]%s%s, win %d%s, length %d%s%s\n", + timeStr, tag(dir, "TCP"), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + flags, seqStr, ackStr, tcp.Window, opts, plen, annotation, verbose) + if err != nil { + return err + } + if tw.ascii && plen > 0 { + return tw.writeASCII(tcp.Payload) + } + return nil +} + +func (tw *TextWriter) writeUDP(timeStr string, dir Direction, info *packetInfo, data []byte) error { + udp := &layers.UDP{} + if err := udp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "UDP", info, data) + } + + plen := len(udp.Payload) + + // DNS replaces the entire line format + if plen > 0 && isDNSPort(info.srcPort, info.dstPort) { + if s := formatDNSPayload(udp.Payload); s != "" { + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d %s%s\n", + timeStr, tag(dir, "UDP"), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + s, verbose) + return err + } + } + + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d length %d%s\n", + timeStr, tag(dir, "UDP"), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + plen, verbose) + if err != nil { + return err + } + if tw.ascii && plen > 0 { + return tw.writeASCII(udp.Payload) + } + return nil +} + +func (tw *TextWriter) writeICMPv4(timeStr string, dir Direction, info *packetInfo, data []byte) error { + icmp := &layers.ICMPv4{} + if err := icmp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "ICMP", info, data) + } + + var detail string + if icmp.TypeCode.Type() == layers.ICMPv4TypeEchoRequest || icmp.TypeCode.Type() == layers.ICMPv4TypeEchoReply { + detail = fmt.Sprintf("%s, id %d, seq %d", icmp.TypeCode.String(), icmp.Id, icmp.Seq) + } else { + detail = icmp.TypeCode.String() + } + + var 
verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s, length %d%s\n", + timeStr, tag(dir, "ICMP"), info.srcIP, info.dstIP, detail, len(data)-info.hdrLen, verbose) + return err +} + +func (tw *TextWriter) writeICMPv6(timeStr string, dir Direction, info *packetInfo, data []byte) error { + icmp := &layers.ICMPv6{} + if err := icmp.DecodeFromBytes(data[info.hdrLen:], gopacket.NilDecodeFeedback); err != nil { + return tw.writeFallback(timeStr, dir, "ICMP", info, data) + } + + var verbose string + if tw.verbose { + verbose = tw.verboseIP(data, info.family) + } + _, err := fmt.Fprintf(tw.w, "%s %s %s > %s %s, length %d%s\n", + timeStr, tag(dir, "ICMP"), info.srcIP, info.dstIP, icmp.TypeCode.String(), len(data)-info.hdrLen, verbose) + return err +} + +func (tw *TextWriter) writeFallback(timeStr string, dir Direction, proto string, info *packetInfo, data []byte) error { + _, err := fmt.Fprintf(tw.w, "%s %s %s:%d > %s:%d length %d\n", + timeStr, tag(dir, proto), + info.srcIP, info.srcPort, info.dstIP, info.dstPort, + len(data)-info.hdrLen) + return err +} + +func (tw *TextWriter) verboseIP(data []byte, family uint8) string { + return fmt.Sprintf(", ttl %d, id %d, iplen %d", + ipTTL(data, family), ipID(data, family), len(data)) +} + +// relativeSeqAck returns seq/ack relative to the first seen value per direction. 
+func (tw *TextWriter) relativeSeqAck(info *packetInfo, seq, ack uint32) (relSeq, relAck uint32) { + fwd := dirKey{ + src: netip.AddrPortFrom(info.srcIP, info.srcPort), + dst: netip.AddrPortFrom(info.dstIP, info.dstPort), + } + rev := dirKey{ + src: netip.AddrPortFrom(info.dstIP, info.dstPort), + dst: netip.AddrPortFrom(info.srcIP, info.srcPort), + } + + if isn, ok := tw.flows[fwd]; ok { + relSeq = seq - isn + } else { + tw.flows[fwd] = seq + } + + if isn, ok := tw.flows[rev]; ok { + relAck = ack - isn + } else { + relAck = ack + } + + return relSeq, relAck +} + +// writeASCII prints payload bytes as printable ASCII. +func (tw *TextWriter) writeASCII(payload []byte) error { + if len(payload) == 0 { + return nil + } + buf := make([]byte, len(payload)) + for i, b := range payload { + switch { + case b >= 0x20 && b < 0x7f: + buf[i] = b + case b == '\n' || b == '\r' || b == '\t': + buf[i] = b + default: + buf[i] = '.' + } + } + _, err := fmt.Fprintf(tw.w, "%s\n", buf) + return err +} + +// --- TCP helpers --- + +func ipTTL(data []byte, family uint8) uint8 { + if family == 4 && len(data) > 8 { + return data[8] + } + if family == 6 && len(data) > 7 { + return data[7] + } + return 0 +} + +func ipID(data []byte, family uint8) uint16 { + if family == 4 && len(data) >= 6 { + return binary.BigEndian.Uint16(data[4:6]) + } + return 0 +} + +func tcpFlagsStr(tcp *layers.TCP) string { + var buf [6]byte + n := 0 + if tcp.SYN { + buf[n] = 'S' + n++ + } + if tcp.FIN { + buf[n] = 'F' + n++ + } + if tcp.RST { + buf[n] = 'R' + n++ + } + if tcp.PSH { + buf[n] = 'P' + n++ + } + if tcp.ACK { + buf[n] = '.' 
+ n++ + } + if tcp.URG { + buf[n] = 'U' + n++ + } + if n == 0 { + return "none" + } + return string(buf[:n]) +} + +func formatTCPOptions(opts []layers.TCPOption) string { + var parts []string + for _, opt := range opts { + switch opt.OptionType { + case layers.TCPOptionKindEndList: + return strings.Join(parts, ",") + case layers.TCPOptionKindNop: + parts = append(parts, "nop") + case layers.TCPOptionKindMSS: + if len(opt.OptionData) == 2 { + parts = append(parts, fmt.Sprintf("mss %d", binary.BigEndian.Uint16(opt.OptionData))) + } + case layers.TCPOptionKindWindowScale: + if len(opt.OptionData) == 1 { + parts = append(parts, fmt.Sprintf("wscale %d", opt.OptionData[0])) + } + case layers.TCPOptionKindSACKPermitted: + parts = append(parts, "sackOK") + case layers.TCPOptionKindSACK: + blocks := len(opt.OptionData) / 8 + parts = append(parts, fmt.Sprintf("sack %d", blocks)) + case layers.TCPOptionKindTimestamps: + if len(opt.OptionData) == 8 { + tsval := binary.BigEndian.Uint32(opt.OptionData[0:4]) + tsecr := binary.BigEndian.Uint32(opt.OptionData[4:8]) + parts = append(parts, fmt.Sprintf("TS val %d ecr %d", tsval, tsecr)) + } + } + } + return strings.Join(parts, ",") +} + +// --- Protocol annotation --- + +// annotatePayload returns a protocol annotation string for known application protocols. 
+func annotatePayload(payload []byte) string { + if len(payload) < 4 { + return "" + } + + s := string(payload) + + // SSH banner: "SSH-2.0-OpenSSH_9.6\r\n" + if strings.HasPrefix(s, "SSH-") { + if end := strings.IndexByte(s, '\r'); end > 0 && end < 256 { + return ": " + s[:end] + } + } + + // TLS records + if ann := annotateTLS(payload); ann != "" { + return ": " + ann + } + + // HTTP request or response + for _, method := range [...]string{"GET ", "POST ", "PUT ", "DELETE ", "HEAD ", "PATCH ", "OPTIONS ", "CONNECT "} { + if strings.HasPrefix(s, method) { + if end := strings.IndexByte(s, '\r'); end > 0 && end < 200 { + return ": " + s[:end] + } + } + } + if strings.HasPrefix(s, "HTTP/") { + if end := strings.IndexByte(s, '\r'); end > 0 && end < 200 { + return ": " + s[:end] + } + } + + return "" +} + +// annotateTLS returns a description for TLS handshake and alert records. +func annotateTLS(data []byte) string { + if len(data) < 6 { + return "" + } + + switch data[0] { + case 0x16: + return annotateTLSHandshake(data) + case 0x15: + return annotateTLSAlert(data) + } + return "" +} + +func annotateTLSHandshake(data []byte) string { + if len(data) < 10 { + return "" + } + switch data[5] { + case 0x01: + if sni := extractSNI(data); sni != "" { + return "TLS ClientHello SNI=" + sni + } + return "TLS ClientHello" + case 0x02: + return "TLS ServerHello" + } + return "" +} + +func annotateTLSAlert(data []byte) string { + if len(data) < 7 { + return "" + } + severity := "warning" + if data[5] == 2 { + severity = "fatal" + } + return fmt.Sprintf("TLS Alert %s %s", severity, tlsAlertDesc(data[6])) +} + +func tlsAlertDesc(code byte) string { + switch code { + case 0: + return "close_notify" + case 10: + return "unexpected_message" + case 40: + return "handshake_failure" + case 42: + return "bad_certificate" + case 43: + return "unsupported_certificate" + case 44: + return "certificate_revoked" + case 45: + return "certificate_expired" + case 48: + return "unknown_ca" + case 
49: + return "access_denied" + case 50: + return "decode_error" + case 70: + return "protocol_version" + case 80: + return "internal_error" + case 86: + return "inappropriate_fallback" + case 90: + return "user_canceled" + case 112: + return "unrecognized_name" + default: + return fmt.Sprintf("alert(%d)", code) + } +} + +// extractSNI parses a TLS ClientHello and returns the SNI server name. +func extractSNI(data []byte) string { + if len(data) < 6 || data[0] != 0x16 { + return "" + } + recordLen := int(binary.BigEndian.Uint16(data[3:5])) + handshake := data[5:] + if len(handshake) > recordLen { + handshake = handshake[:recordLen] + } + + if len(handshake) < 4 || handshake[0] != 0x01 { + return "" + } + hsLen := int(handshake[1])<<16 | int(handshake[2])<<8 | int(handshake[3]) + body := handshake[4:] + if len(body) > hsLen { + body = body[:hsLen] + } + + extPos := clientHelloExtensionsOffset(body) + if extPos < 0 { + return "" + } + return findSNIExtension(body, extPos) +} + +// clientHelloExtensionsOffset returns the byte offset where extensions begin +// within the ClientHello body, or -1 if the body is too short. 
+func clientHelloExtensionsOffset(body []byte) int { + if len(body) < 38 { + return -1 + } + pos := 34 + + if pos >= len(body) { + return -1 + } + pos += 1 + int(body[pos]) // session ID + + if pos+2 > len(body) { + return -1 + } + pos += 2 + int(binary.BigEndian.Uint16(body[pos:pos+2])) // cipher suites + + if pos >= len(body) { + return -1 + } + pos += 1 + int(body[pos]) // compression methods + + if pos+2 > len(body) { + return -1 + } + return pos +} + +func findSNIExtension(body []byte, pos int) string { + extLen := int(binary.BigEndian.Uint16(body[pos : pos+2])) + pos += 2 + extEnd := pos + extLen + if extEnd > len(body) { + extEnd = len(body) + } + + for pos+4 <= extEnd { + extType := binary.BigEndian.Uint16(body[pos : pos+2]) + eLen := int(binary.BigEndian.Uint16(body[pos+2 : pos+4])) + pos += 4 + if pos+eLen > extEnd { + break + } + if extType == 0 && eLen >= 5 { + nameLen := int(binary.BigEndian.Uint16(body[pos+3 : pos+5])) + if pos+5+nameLen <= extEnd { + return string(body[pos+5 : pos+5+nameLen]) + } + } + pos += eLen + } + return "" +} + +func isDNSPort(src, dst uint16) bool { + return src == 53 || dst == 53 || src == 5353 || dst == 5353 +} + +// formatDNSPayload parses DNS and returns a tcpdump-style summary. +func formatDNSPayload(payload []byte) string { + d := &layers.DNS{} + if err := d.DecodeFromBytes(payload, gopacket.NilDecodeFeedback); err != nil { + return "" + } + + rd := "" + if d.RD { + rd = "+" + } + + if !d.QR { + return formatDNSQuery(d, rd, len(payload)) + } + return formatDNSResponse(d, rd, len(payload)) +} + +func formatDNSQuery(d *layers.DNS, rd string, plen int) string { + if len(d.Questions) == 0 { + return fmt.Sprintf("%04x%s (%d)", d.ID, rd, plen) + } + q := d.Questions[0] + return fmt.Sprintf("%04x%s %s? %s. 
(%d)", d.ID, rd, q.Type, q.Name, plen) +} + +func formatDNSResponse(d *layers.DNS, rd string, plen int) string { + anCount := d.ANCount + nsCount := d.NSCount + arCount := d.ARCount + + if d.ResponseCode != layers.DNSResponseCodeNoErr { + return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen) + } + + if anCount > 0 && len(d.Answers) > 0 { + rr := d.Answers[0] + if rdata := shortRData(&rr); rdata != "" { + return fmt.Sprintf("%04x %d/%d/%d %s %s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, plen) + } + } + + return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen) +} + +func shortRData(rr *layers.DNSResourceRecord) string { + switch rr.Type { + case layers.DNSTypeA, layers.DNSTypeAAAA: + if rr.IP != nil { + return rr.IP.String() + } + case layers.DNSTypeCNAME: + if len(rr.CNAME) > 0 { + return string(rr.CNAME) + "." + } + case layers.DNSTypePTR: + if len(rr.PTR) > 0 { + return string(rr.PTR) + "." + } + case layers.DNSTypeNS: + if len(rr.NS) > 0 { + return string(rr.NS) + "." + } + case layers.DNSTypeMX: + return fmt.Sprintf("%d %s.", rr.MX.Preference, rr.MX.Name) + case layers.DNSTypeTXT: + if len(rr.TXTs) > 0 { + return fmt.Sprintf("%q", string(rr.TXTs[0])) + } + case layers.DNSTypeSRV: + return fmt.Sprintf("%d %d %d %s.", rr.SRV.Priority, rr.SRV.Weight, rr.SRV.Port, rr.SRV.Name) + } + return "" +}