Compare commits

...

82 Commits

Author SHA1 Message Date
Eduard Gert
842ef0d657 update macos icon 2026-05-11 15:40:04 +02:00
Eduard Gert
439f44c6b4 merge 2026-05-11 15:16:41 +02:00
Eduard Gert
b5a970155b Merge branch 'ui-refactor' into ui-refactor-ui 2026-05-11 15:15:11 +02:00
Eduard Gert
686e0d97f2 update Assets.car 2026-05-11 14:51:05 +02:00
Eduard Gert
0c287b6f4d fix vite dev server 2026-05-11 14:48:37 +02:00
Eduard Gert
f7f5946910 update components 2026-05-11 14:26:10 +02:00
Zoltan Papp
7a9f5a734f Merge branch 'main' into ui-refactor
Port IPv6 overlay support (#5631) into the Wails UI:
- Add DisableIPv6 config toggle to Settings (NetworkTab + services)
- Filter ::/0 alongside 0.0.0.0/0 as an exit-node route
- Suppress duplicate v6 default-route notifications in tray
2026-05-11 14:10:12 +02:00
Eduard Gert
1aae067aaa add settings skeleton 2026-05-11 13:58:41 +02:00
Zoltan Papp
28a7eba756 [client/ui] Remove unused xembed_host_other.go stub 2026-05-11 13:54:17 +02:00
Zoltan Papp
8841b950a2 [client/server] Stop retry loop after PermissionDenied login
Without marking the error as backoff.Permanent the outer retry re-enters
connect(), which resets the daemon state from NeedsLogin to Connecting
and makes the tray flicker between the two until the user logs in.
2026-05-11 13:43:53 +02:00
Eduard Gert
0c2702c0d7 update height and wording 2026-05-11 13:30:05 +02:00
Zoltan Papp
b43a09a1c7 [client/ui] Add tray icon for needs-login/login-failed states
The tray now switches to a dedicated lock icon when the daemon reports
NeedsLogin, SessionExpired or LoginFailed — the latter mirrors the CLI,
which groups these three statuses together as "needs authentication"
and prints the same "Run netbird up" prompt. The macOS template variant
reuses the existing error-macos PNG because the project's macOS tray
PNGs use a 2-color (black + transparent) convention that rsvg-convert
of the badge-style SVG sources can't reproduce. The earlier badge-style
SVG sketches in assets/svg/ are removed (they were marked as reference
only and never matched the shipping PNG design).
2026-05-11 13:22:30 +02:00
Zoltan Papp
595dfbb6f1 [client/ui] Distinguish "daemon not running" tray state
The status stream emits a synthetic StatusDaemonUnavailable when the
gRPC client or stream cannot be established, fired once per outage and
cleared on the next real snapshot. The tray maps it to a "Not running"
status label, switches the icon to the error variant, hides
Connect/Disconnect (neither would work without the daemon), and
disables Settings, Networks and Create Debug Bundle so the user is not
routed to pages that would just fail to load.
2026-05-11 12:22:47 +02:00
Zoltan Papp
7f560df9be [client/ui] Tray menu opens on click; hide window at startup
Left-click on the tray icon now opens the menu on every platform — the
window is reached through a new "Open NetBird" entry. Only the action
that matches the current daemon state is shown: Connect when
disconnected, Disconnect when connected. The main window starts hidden
and is only surfaced via the tray, single-instance launch, or daemon
events.
2026-05-11 12:01:46 +02:00
Zoltán Papp
09052949a2 [client/ui] Finish ui-wails rename (import paths, fyne deps)
Follow-up to the rename commit: the previous commit moved the files but
the post-mv string substitutions (Go imports, frontend bindings, CI
config paths) were not re-staged so they slipped through. This commit
applies those edits and removes the fyne dependencies from go.mod/go.sum
now that the legacy fyne UI is gone.
2026-05-11 11:33:35 +02:00
Zoltán Papp
9aef31ff53 [client/ui] Replace fyne UI with Wails (rename ui-wails to ui)
Removes the legacy fyne-based client/ui implementation and renames the
Wails replacement (client/ui-wails) to take its place at client/ui. Go
imports, frontend bindings, CI workflows, goreleaser configs and the
windows .syso icon path are updated to follow the rename.
2026-05-11 11:20:22 +02:00
Zoltán Papp
08f52f4517 [client/server] Allow clearing pre-shared key via SetConfig
The daemon ignored an empty OptionalPreSharedKey, so a UI/CLI request to
clear the pre-shared key was silently dropped. Pass the pointer through
unconditionally — profilemanager already handles the empty-string case.
2026-05-11 11:02:39 +02:00
Viktor Liu
a4114a5e45 [client] Skip DNS upstream failover on definitive EDE (#6089) 2026-05-11 10:00:23 +02:00
Viktor Liu
6b08e89c7b [relay] Preserve non-standard port in WS dialer URL prep (#6061) 2026-05-11 09:59:33 +02:00
Viktor Liu
a852b3bd34 [client, proxy] Harden uspfilter conntrack and share TCP relay (#5936) 2026-05-11 09:59:13 +02:00
Viktor Liu
afb83b3049 [client] Use unique temp file and clean up on failure when writing ssh config (#6064) 2026-05-11 09:58:49 +02:00
Eduard Gert
18e3b5dd32 fix about 2026-05-11 09:37:14 +02:00
Eduard Gert
f3f9704c6f update about 2026-05-08 17:55:41 +02:00
Eduard Gert
4c3d4effbd update troubleshooting 2026-05-08 17:18:25 +02:00
Nicolas Frati
e89aad09f5 [management] Enable MFA for local users (#5804)
* wip: totp for local users

* fix providers not getting populated

* polished UI and fix post_login_redirect_uri

* fix: make sure logout is only prompted from oidc flow

Signed-off-by: jnfrati <nicofrati@gmail.com>

* update templates

Signed-off-by: jnfrati <nicofrati@gmail.com>

* deps: update dex dependency

Signed-off-by: jnfrati <nicofrati@gmail.com>

* fix qube issues

Signed-off-by: jnfrati <nicofrati@gmail.com>

* replace window with globalThis on home html

Signed-off-by: jnfrati <nicofrati@gmail.com>

* fixed coderabbit comments

Signed-off-by: jnfrati <nicofrati@gmail.com>

* debug

* remove unused config and rename totp issuer

* deps: update dex reference to latest

* add dashboard post logout redirect uri to embedded config

* implemented api for mfa configuration

* update docs and config parsing

* catch error on idp manager init mfa

* fix tests

* Add remember me  for MFA

* Add cookie encryption and session share between tabs

* fixed logout showing non actionable error and session cookie encription key

* fixed missing mfa settings on sql query for account

* fix code index for mfa activity

---------

Signed-off-by: jnfrati <nicofrati@gmail.com>
Co-authored-by: braginini <bangvalo@gmail.com>
2026-05-08 16:31:20 +02:00
Eduard Gert
3953fee5a4 update ssh and advanced settings tabs 2026-05-08 10:57:31 +02:00
Eduard Gert
adeaa49cda update switch 2026-05-07 17:27:56 +02:00
Eduard Gert
2c5d52a1bf update wording 2026-05-07 17:19:56 +02:00
Eduard Gert
70a755fbae add general settings 2026-05-07 16:47:52 +02:00
Maycon Santos
7da94a4956 [misc] Update CONTRIBUTING.md (#6076) 2026-05-07 16:16:48 +02:00
Pascal Fischer
39eac377e4 [management] add update reason to buffered calls (#6103) 2026-05-07 15:55:59 +02:00
Eduard Gert
559da5d5b9 refactor 2026-05-07 15:00:36 +02:00
Eduard Gert
614ee11ac7 update CFBundleDisplayName 2026-05-07 14:19:34 +02:00
Eduard Gert
85080afa59 use new mac style icons 2026-05-07 14:14:26 +02:00
Zoltán Papp
a5cc8da054 [client] Pre-seed CustomActivator CLSID under HKCU AppUserModelId\NetBird
The Wails notifications service reads HKCU\Software\Classes\AppUserModelId\
<AppName>\CustomActivator on first startup; if present it uses that GUID
as the toast activator CLSID, otherwise it generates a fresh UUID and
writes it back. Without an installer-supplied value the per-machine GUID
diverges from the ToastActivatorCLSID baked into the Start Menu and
Desktop shortcuts, and the COM activator never fires when a toast is
clicked. Seed the same CLSID the shortcuts use so the two sides match.
2026-05-07 13:00:51 +02:00
Zoltán Papp
a4fd5a78b4 [client/ui-wails] Set application Name to "NetBird" for Windows toasts
Windows uses application.Options.Name as the toast AppUserModelID and as
the registry path the Wails notifier reads/writes its CustomActivator
under (HKCU\Software\Classes\AppUserModelId\<Name>). The MSI installer
seeds those under "NetBird"; with the previous "netbird-ui" Name the app
would have written under a different identity and the toast activator
CLSID the installer pre-registers would have been orphaned.
2026-05-07 12:59:01 +02:00
Eduard Gert
062a183e4e update settings nav 2026-05-07 12:40:04 +02:00
Viktor Liu
205ebcfda2 [management, client] Add IPv6 overlay support (#5631) 2026-05-07 11:33:37 +02:00
Eduard Gert
a2be41caf8 add about setting 2026-05-07 11:24:11 +02:00
Zoltán Papp
5b70989e3e [client/ui-wails] Make /update page faithful to the legacy auto-update dialog
Adds the missing info line ("Your client version is older than the
auto-update version set in Management. Updating client to: <version>.")
and replaces the spinner with the legacy 1-second dot animation
(Updating./.../...). Terminal-state wording now matches the Fyne UI
exactly: 15 min timeout, canceled, and "Update failed: <err>".

Ports mapInstallError from client/ui/update.go so daemon errors that
embed "deadline exceeded" / "canceled" hit the right branch instead of
falling through as a generic failure.

Detects the daemon dropping mid-upgrade (the legacy success signal):
if GetInstallerResult fails for 5s straight, call the new Update.Quit
service method to exit, mirroring app.Quit() in showInstallerResult.
2026-05-07 10:35:18 +02:00
Zoltán Papp
d324a5ff48 [ci] Stub frontend/dist before lint so the Wails embed pattern matches
client/ui-wails/main.go embeds all:frontend/dist, which is produced by
the frontend build and gitignored. Lint runs don't build the frontend,
so the directory is missing in CI and golangci-lint fails the typecheck.
Create a placeholder file before linting so the embed has something to
match.
2026-05-07 10:23:02 +02:00
Eduard Gert
debb558aa3 wip 2026-05-07 09:57:14 +02:00
Zoltán Papp
cce80f8276 [client/ui-wails] Drop dead freebsd branches in services/connection.go
The file's build constraint excludes freebsd, so the freebsd cases in
IsUnixDesktopClient and OpenURL were unreachable — staticcheck (SA4032)
fails the pre-push lint. Linux is the only Unix-desktop GOOS this
package compiles for, so collapse both checks accordingly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 18:00:51 +02:00
Zoltán Papp
05ee4e52b8 [client/ui-wails] Make the SSO login flow recoverable from a stuck state
A pending WaitSSOLogin parks the daemon on an OAuth UserCode forever
once the user closes the browser without completing the flow. The
frontend can't unblock that on its own — it needs the daemon to fire
its own actCancel(). Three fixes work together:

- Login() now issues a Down() before kicking off the new flow so a
  previously-stuck WaitSSOLogin is unwedged before we ask the daemon
  for fresh OAuth info.
- The Login page's Cancel button calls Down() before navigating away,
  so abandoning the flow mid-browser actually settles the daemon's
  in-flight WaitSSOLogin instead of leaving it pinned.
- Status keeps the Login button visible whenever we aren't Connected
  (including Connecting), so a UI restart that finds the daemon stuck
  in Connecting still has a one-click recovery path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 17:59:50 +02:00
Zoltán Papp
bb2bf673a0 [client/ui-wails] Wire up the SSO login flow end-to-end
Mirror the Fyne client's login path: the daemon Login RPC now defaults
ProfileName/Username from GetActiveProfile + the OS user and sets
IsUnixDesktopClient on Linux/FreeBSD so the daemon picks the SSO
browser flow. A new OpenURL service launches the user's default
browser via xdg-open / open / rundll32 (Fyne's openURL helper) — the
embedded WebKit's window.open silently fails for external URLs.

The frontend gains a Login page that drives the full Login →
window.open via OpenURL → WaitSSOLogin → Up sequence with progress
states. Status surfaces a Login button while the daemon reports
NeedsLogin/SessionExpired, and the tray's status row stops being a
purely-decorative label: it becomes a clickable Login entry whenever
re-authentication is required.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 17:48:47 +02:00
Zoltán Papp
91c745e5e8 [client/ui-wails] Tear down the whole tray popup tree on focus loss
Replace the per-submenu focus-out handler with a shared idle-deferred
recheck: when any popup loses focus, ask after the next event-loop
turn whether *any* of our popups still owns toplevel focus. If none
does, the user clicked outside the menu tree, so close every popup at
once instead of leaking the parent.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 17:19:55 +02:00
Zoltán Papp
68c38247f1 [client/ui-wails] Add submenu support to the XEmbed tray popup
Recursively walk dbusmenu children-display="submenu" entries when
flattening the SNI menu so the GTK popup can render nested items.
The C side renders submenu folders as labeled buttons that open a
child popup window aligned to the anchor row, kept on-screen with
horizontal flipping; the top-level popup no longer self-destructs
when focus transfers to one of its own submenus.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 17:17:54 +02:00
Zoltan Papp
f23aaa9ae7 [client] iOS: structured ResolvedIPs collection for domain routes (#6090)
* [client] iOS: structured ResolvedIPs collection for domain routes

Replace comma-joined ResolvedIPs string with a gomobile-friendly
ResolvedIPs collection (Add/Get/Size), mirroring the Android bridge
in client/android/network_domains.go.

This allows the iOS app to match domain-route resolved IPs against
connected peer routes without parsing CSV strings, fixing the route
status indicator for dynamic (DNS) routes.

* [client] iOS: align dynamic route exposure with Android bridge

For dynamic (DNS) routes the Swift side previously received
"invalid Prefix" as the Network value, forcing UI code to special-case
that sentinel. The Android bridge uses Domains.SafeString() instead so
peer.routes entries (which also derive from Domains.SafeString()) match
directly. Mirror that here.

Also fix the resolved IP lookup: resolvedDomains is keyed by the
resolved domain (e.g. api.ipify.org), not the configured pattern
(e.g. *.ipify.org). Group entries by ParentDomain like the daemon does
in client/server/network.go, so wildcard route patterns get their
resolved IPs populated.
2026-05-06 17:14:11 +02:00
Zoltán Papp
8b8f38de1b [client/ui-wails] Show GUI and daemon versions in the About submenu
Restore the legacy Fyne UI's two disabled "GUI: x.y.z" / "Daemon: a.b.c"
entries under About so users (and support) can read the running
versions from the tray. The GUI line is baked in at build time via
version.NetbirdVersion() — the same -ldflags chain the rest of the
repo uses. The daemon line starts as "—" and is rewritten in
applyStatus on every Status snapshot whose DaemonVersion differs from
the last one we recorded, so a daemon restart with a new build
(e.g. after an enforced update) updates the menu automatically.

Drive-by: rename the local variable that shadowed the version package
in handleUpdate so the import resolves cleanly.
2026-05-06 16:55:52 +02:00
Zoltán Papp
2b272e74c8 [client/ui-wails] In-process StatusNotifierWatcher + XEmbed tray bridge
Wails3's Linux systray hands the icon off to whatever process owns
org.kde.StatusNotifierWatcher on the session bus. Bare WMs (Fluxbox,
OpenBox, i3, dwm, sway, vanilla GNOME without the AppIndicator
extension) ship no watcher, so the icon registration silently fails
and the tray never appears — leaving a tray-only app like NetBird
unreachable.

Add a Linux-only watcher fallback that claims the watcher name when
nobody else does, plus an XEmbed bridge so legacy X11 system trays
(_NET_SYSTEM_TRAY_S0) can still render the icon. Both no-op on other
platforms via build tags.

Pieces:
- tray_watcher_linux.go: claims org.kde.StatusNotifierWatcher on a
  private session bus, exports the bare RegisterStatusNotifierItem /
  RegisterStatusNotifierHost surface, and spins up an XEmbed host per
  registered SNI item.
- xembed_host_linux.go: per-item event loop. Polls X11 events with a
  50ms ticker, listens for the SNI NewIcon signal, dispatches Activate
  / context menu through dbusmenu (com.canonical.dbusmenu).
- xembed_tray_linux.{c,h}: the X11/cairo native bits. Window is created
  with CopyFromParent visual + ParentRelative background so transparent
  pixels show the toolbar beneath instead of solid black on 24-bit
  trays. cairo paints the IconPixmap with OVER blending so per-pixel
  alpha is honoured against the parent-relative base. GTK3 owns the
  context-menu popup; menu items round-trip through dbusmenu Event.
- tray_linux.go: forces WEBKIT_DISABLE_DMABUF_RENDERER=1 in init() so
  developers running `task dev` / launching the binary directly get the
  same software rendering path the .desktop launcher already enables;
  the deb/rpm Exec wrapper covers installed users.
- tray_watcher_other.go and xembed_host_other.go: build-tag stubs so
  main.go's startStatusNotifierWatcher() compiles on every platform.
- main.go: calls startStatusNotifierWatcher() before NewTray so the
  Wails systray's RegisterStatusNotifierItem call hits a watcher we
  control on bare WMs.
- build/linux/netbird-ui.desktop: regenerated by `task build` to wrap
  the dev launcher's Exec line with the WEBKIT_DISABLE_DMABUF_RENDERER
  env, matching what the tray_linux.go init does at runtime.

Adapted from work originally prototyped on the prototype/ui-wails branch.

Tested on Fluxbox (Debian 13): the icon appears in the slit/toolbar with
the toolbar's background showing through transparent pixels, left-click
opens the window, right-click brings up the GTK popup of the dbusmenu
items.
2026-05-06 16:47:35 +02:00
Zoltán Papp
e6cbf30415 [client/ui-wails] Surface daemon SessionExpired in the tray
Port the Fyne UI's onSessionExpire 1:1 to the Wails tray so an SSO token
expiry no longer leaves the user staring at a stale peer list. When
applyStatus sees the transition into the daemon's StatusSessionExpired,
fire a single OS notification (the lastStatus guard rate-limits it to
the transition itself, mirroring the Fyne sendNotification flag) and
bring the main window forward on the /login route so the frontend can
drive the renewed SSO flow. The Fyne client achieved the same end with
a runSelfCommand "login-url" helper; here the window is already
in-process so we route to it directly.
2026-05-06 15:57:34 +02:00
Zoltán Papp
490b60ad0e [ci] Suppress typecheck on the ui-wails embed instead of skipping main.go
The previous attempt added client/ui-wails/main.go to the file path
exclude list, but golangci-lint v2's path filter only suppresses
issues from rule-based linters; the typecheck pre-pass that compiles
the package still runs and fails with "pattern all:frontend/dist: no
matching files found" before any rule fires.

Replace the path-level skip with a targeted exclusions.rules entry
that matches just that diagnostic on just that file. The rest of
client/ui-wails (services/, tray.go, grpc.go, ...) keeps being linted
normally.

Validated locally by deleting frontend/dist and running
`golangci-lint run client/ui-wails/...` — 0 issues with this config.
2026-05-06 15:50:14 +02:00
Eduard Gert
553be144b4 add setting 2026-05-06 14:21:01 +02:00
Eduard Gert
c3f9514182 wip 2026-05-06 10:47:40 +02:00
Zoltán Papp
a8812d5fb1 Merge remote-tracking branch 'origin/main' into ui-refactor
# Conflicts:
#	go.mod
#	go.sum
2026-05-05 15:41:59 +02:00
Zoltán Papp
6f93cf6ac3 [client/ui-wails] Group Tray's services into a TrayServices struct
NewTray's eight-parameter signature crossed Sonar's seven-parameter
threshold once Update joined the dependency list. Bundle the six service
pointers (Connection, Settings, Profiles, Peers, Notifier, Update) into
a TrayServices struct, leaving NewTray with three arguments — the two
Wails platform handles plus the service bag. Tray.svc replaces the
individual fields; call sites use t.svc.Connection etc.

Adding another service later is now a one-line struct change instead
of a NewTray signature break.
2026-05-05 15:37:25 +02:00
Zoltán Papp
18909390c2 [ci] Use go list -e so the ui-wails embed doesn't blank the test list
The previous fix added /client/ui-wails to the grep -v / Where-Object
filter, but go list aborts at the first broken package and emits an
empty stdout when client/ui-wails/main.go's //go:embed all:frontend/dist
fails to resolve. The command substitution then expands to nothing, and
`go test` falls back to the repo root — which has no Go files and fails
the job.

`go list -e` keeps listing remaining packages after a parse error, so
the existing path-based filter now actually does its job.

Touches all three test workflows (Linux native + docker, Darwin, Windows).
2026-05-05 15:30:40 +02:00
Zoltán Papp
b3eb5f2453 [ci] Skip lockfiles in codespell
pnpm-lock.yaml and package-lock.json embed package hashes that look
like English words to codespell (e.g. "nD" -> "and"), causing false
positives that can't be fixed because the lockfile is auto-generated.
Add the standard lockfile patterns to the skip list alongside the
existing go.mod/go.sum/proxy-web entries.
2026-05-05 15:15:15 +02:00
Zoltán Papp
dc02542a9e [ci] Skip client/ui-wails/main.go in golangci-lint
main.go uses //go:embed all:frontend/dist, which fails the typecheck
phase when frontend/dist is empty (the release pipeline populates it
via `pnpm build`; the lint workflow does not). Excluding just main.go
keeps the rest of the package — services/, tray.go, grpc.go, the
signal handlers — in scope.
2026-05-05 15:12:49 +02:00
Zoltán Papp
0c136fffb9 [ci] Add sonar-project.properties to exclude the Wails React frontend
Sonar's default scanner picks up TypeScript / JSX from the frontend
tree but applies rules that don't fit a UI codebase reviewed visually
(component dead-code detection, hook-shape conventions, ...). Skip
client/ui-wails/frontend from both analysis and coverage so neither
the rules engine nor the coverage gate trips on UI changes.

The Go side of the Wails UI (client/ui-wails/*.go, services/) is left
in scope on purpose — same Go standards as the rest of the repo.
2026-05-05 15:10:23 +02:00
Zoltán Papp
fffb9dd219 [client/ui-wails] Add Forwarding service for the exposed-services list
Surfaces the daemon's existing ForwardingRules RPC as a Wails service so
the React frontend can render the reverse-proxy / exposed-services list
in the planned dashboard.

Forwarding.List() returns one ForwardingRule per active rule with
protocol, destination port (single or range), translated address /
hostname, and translated port. The PortInfo oneof from the proto is
flattened to a `{port?: number, range?: {start, end}}` shape so TS
consumers don't have to peek at proto-internal type discriminators.

Regenerate frontend/bindings (forwarding.ts, models.ts, index.ts) so
the React side picks up the new service. peers.ts churn is a doc
comment refresh only — no API change.
2026-05-05 13:53:40 +02:00
Zoltán Papp
93275f9052 Bump github.com/wailsapp/wails/v3 to v3.0.0-alpha.84
Picks up the alpha.84 patch series. The only API change relative to
alpha.78 is a new macOS Liquid Glass effect option (NSGlassEffectView)
that NetBird does not use, so this is a drop-in dependency bump.

netbird-ui builds cleanly, go vet has no new findings, and the existing
Linux tray workaround (skip AttachWindow + OnClick on Linux) is still
required — Wails3 systemtray_linux.go's openMenu remains a "not
implemented on Linux" stub and SystemTray.applySmartDefaults still
auto-installs ToggleWindow as the click handler when a window is
attached.

The alpha CLI's transitive github.com/goreleaser/nfpm/v2 v2.44.1 is not
imported by any NetBird production binary (verified with `go list -deps`
on netbird-ui and the daemon entry points); it only ships inside the
wails3 developer CLI used for local packaging. The Snyk advisory for
nfpm therefore does not affect netbird-ui or the daemon.
2026-05-05 13:09:37 +02:00
Zoltán Papp
dd9c15072f [ci] Skip client/ui-wails in go test runs
main.go embeds frontend/dist with //go:embed, so any go-list-based test
sweep that touches the package fails at compile time before pnpm build
has populated the directory. The release pipeline runs the frontend
build via the goreleaser before-hook; the test workflows do not, and
should not, ship a Node toolchain just to compile a UI binary that has
no Go-side unit tests anyway.

Add a /client/ui-wails exclude to the test go-list filter on Linux,
Darwin and Windows.
2026-05-05 12:56:59 +02:00
Zoltán Papp
4c743bc03d Merge remote-tracking branch 'origin/main' into ui-refactor
# Conflicts:
#	client/internal/peer/status.go
#	client/proto/daemon.pb.go
#	client/proto/daemon_grpc.pb.go
#	go.mod
2026-05-05 12:49:09 +02:00
Zoltán Papp
2e61b42e92 [client/ui-wails] Slim the tray menu, move toggles to Settings page
The Fyne 1:1 tray pulled the entire daemon-config knobset (Allow SSH,
Connect on Startup, Quantum-Resistance, Lazy Connections, Block Inbound,
Notifications) into a Settings submenu — useful in a tray-only UI but
redundant now that the Wails app has a real Settings page. Drop the
submenu and route a single top-level "Settings" entry to /settings;
"Create Debug Bundle" stays at the top level for support workflows.

Side effects:
  - flipFlag and ptrBool go away with the checkbox callbacks.
  - loadConfig keeps seeding notificationsEnabled (the tray still gates
    OS toasts in onSystemEvent on it) but no longer mirrors any other
    config field.
  - Unused menu/notify constants (Allow SSH, Connect on Startup, ...,
    notifyErrorSettingsFmt) are removed from the central const block.
2026-05-05 12:19:41 +02:00
Zoltán Papp
3f8de2a149 [client/ui-wails] Hide Dock entry on macOS via LSUIElement
The legacy Fyne client and the sign-pipelines-built .pkg both run NetBird
in macOS Accessory mode (LSUIElement=1) — tray-only, no Dock entry, no
Cmd-Tab presence. The Wails build's bundled Info.plist (used by `task
darwin:package` for local development) didn't carry the flag, so the
.app bundle a developer builds locally diverged from the signed release.

Add LSUIElement to both Info.plist and Info.dev.plist so the local dev
flow matches what users see.
2026-05-05 12:03:09 +02:00
Zoltán Papp
bc609c3ae7 [client/ui-wails] Wire up enforced-update tray menu item
Surface the Fyne UI's "Download latest version" / "Install version X.Y.Z"
About-submenu entry in the Wails tray. The item starts hidden and is
revealed by onUpdateAvailable when the daemon emits EventUpdateAvailable;
opt-in updates open github.com/netbirdio/netbird/releases/latest in the
browser, enforced updates surface the in-window /update progress page
and call TriggerUpdate on the daemon.

Also lift every user-facing string and external URL in tray.go into
named const declarations at the top of the file, so future copy edits
and (eventual) localisation have a single source of truth.

The /update React route is the frontend counterpart and is owned by the
React side of the refactor.
2026-05-05 11:56:57 +02:00
Zoltán Papp
e3994d0c99 [client] Drop Mesa3D opengl32.dll, bootstrap WebView2 in Windows installers
Wails3 uses the WebKit-style WebView2 runtime instead of Fyne's OpenGL
backend, so the Mesa3D opengl32.dll payload that the Fyne build needed
for RDP/VM rendering can leave the .exe and .msi installers. Add a
WebView2 bootstrap step that probes the EdgeUpdate registry markers
(both HKLM\WOW6432Node and HKCU) and silently runs
MicrosoftEdgeWebview2Setup.exe only if the runtime is missing.

NSIS uses an inline macro adapted from Wails3's wails_tools.nsh; WiX
uses a deferred CustomAction gated on RegistrySearch properties. Both
expect the bootstrapper payload at client/MicrosoftEdgeWebview2Setup.exe,
which the sign-pipelines build step generates with `wails3 generate
webview2bootstrapper`. The matching sign-pipelines change lives in
that repo's PR.

The uninstall section keeps an unconditional `Delete opengl32.dll` so
upgrades from older Fyne builds clean up the leftover file.
2026-05-04 17:36:30 +02:00
Zoltán Papp
ba6e10cef3 [client/ui-wails] Pad macOS tray PNGs for proper menubar sizing
Wails3's macOS systray sets the NSImage size to the status bar thickness
(~22pt) on a square frame. The legacy Fyne PNGs had almost no horizontal
margin (the logo filled all 256x256), so under that explicit resize the
glyph stretched to the full menubar height — noticeably larger than
neighbouring SF Symbols-style indicators.

Pad each *-macos.png from 256x256 to 366x366 with transparent gravity:center
extent, leaving the glyph at ~70% of the rendered size. Same source PNGs,
no resampling, just more breathing room around the alpha-only template.
2026-05-04 17:12:12 +02:00
Zoltán Papp
ce53981b55 [client/ui-wails] Fix Windows manifest version format
Win32 assembly manifests require a four-part version (MAJOR.MINOR.BUILD.REVISION
per the Microsoft schema). The Wails template shipped a three-part "0.0.1",
which Windows rejects with "Activation context generation failed (...) The
value 0.0.1 of attribute version in element assemblyIdentity is invalid",
so the .exe never reaches main(). Pad to "0.0.1.0".
2026-05-04 16:20:15 +02:00
Zoltán Papp
a69037630b [client/ui-wails] Skip tray click-to-toggle on Linux
GNOME Shell + AppIndicator extension opens the attached menu on
left-click in addition to firing SNI Activate, so binding the window
toggle to the click handler made both the window and the menu pop on a
single click. The default Wails3 SystemTray.applySmartDefaults made it
worse: AttachWindow alone is enough to install ToggleWindow as the
implicit click handler, so dropping OnClick wasn't sufficient.

Mirror the legacy Fyne client: skip both AttachWindow and OnClick on
Linux and expose the main window through an explicit "Open NetBird"
menu item. Windows and macOS keep the click-to-toggle behaviour where
the OS cleanly separates left and right click.
2026-05-04 16:08:10 +02:00
Zoltán Papp
df58935cc0 [client/ui-wails] Set NetBird window and app icon on Linux
Wails3 falls back to its bundled bird logo when no Icon is supplied via
application.Options or LinuxWindow. Embed the 256x256 NetBird PNG and
wire it through both fields, plus set ProgramName=netbird so GTK's
g_set_prgname picks up the icon from the installed .desktop file. Tested
on Fedora 40 + KDE Plasma; the titlebar and taskbar now show the NetBird
logo.
2026-05-04 14:34:45 +02:00
Zoltán Papp
a1743dbf9b [client/ui-wails] Fix Fedora ayatana-appindicator package name
The RPM dependency name on Fedora is libayatana-appindicator-gtk3 (not
libayatana-appindicator3 — that's the Debian/Ubuntu spelling). Verified
with dnf install on Fedora 40.
2026-05-04 14:00:52 +02:00
Zoltán Papp
f9771de3f5 [client/ui-wails] Switch release pipelines from Fyne to Wails UI
Repoint goreleaser configs and the release workflow at client/ui-wails so
the published Linux deb/rpm, Windows binaries and macOS UI binaries are
built from the Wails source. Linux nfpm deps swap libappindicator/Fyne
GL stack for libgtk-3, libwebkit2gtk-4.1 and libayatana-appindicator3,
and the packaged .desktop file launches the binary with
WEBKIT_DISABLE_DMABUF_RENDERER=1 so RDP/VM sessions render correctly.
Frontend bindings are now committed; the release jobs add Node 20 and
pnpm 9 and run the frontend build via the goreleaser before-hook.
2026-05-04 13:00:13 +02:00
Eduard Gert
bfe19fa542 wip 2026-05-04 10:15:29 +02:00
Eduard Gert
d07f25fc49 wip 2026-05-04 10:14:41 +02:00
Eduard Gert
670b0f66ac Merge branch 'ui-refactor' into ui-refactor-ui 2026-04-30 14:57:32 +02:00
Eduard Gert
15d73a2edd Add connect toggle 2026-04-30 13:22:43 +02:00
Zoltán Papp
88a2bf582d [client] Push-based status stream for the Wails UI
Adds a SubscribeStatus gRPC RPC that pushes a fresh FullStatus snapshot
on every peer-recorder state change, replacing the Wails UI's 2-second
Status poll. The daemon's notifier already triggers on Connected /
Disconnected / Connecting / management or signal flip / address
change / peers-list change; we now coalesce those into ticks on a
buffered chan and stream the resulting snapshots over gRPC.

- Status recorder gains SubscribeToStateChanges /
  UnsubscribeFromStateChanges + a non-blocking notifyStateChange that
  drops ticks when a subscriber's 1-slot buffer is full (next snapshot
  the consumer pulls already reflects everything).
- Server.Status handler split: the snapshot composition is shared
  with the new SubscribeStatus stream handler so unary and stream
  paths return identical bytes.
- UI peers service: pollLoop replaced by statusStreamLoop. The local
  name of the existing SubscribeEvents loop is now toastStreamLoop so
  the two streams are easy to tell apart — the underlying RPC name is
  unchanged.
- Tray applyStatus skips the icon refresh when connected/lastStatus
  hasn't changed; rapid SubscribeStatus bursts during health probes
  no longer churn Shell_NotifyIcon or the log.
2026-04-30 11:45:43 +02:00
Zoltán Papp
0148d926d5 [client/ui-wails] Use original Fyne tray PNGs and drop the .ico split
The SVG-derived tray icons + multi-resolution .ico path looked correct on
disk but Wails3's Shell_NotifyIcon update never landed on the running
Windows tray — the icon stayed frozen on the .exe resource regardless of
how many times we called SetIcon. Single-PNG fed through the same path
updates correctly, so revert to the source-of-truth PNGs that ship with
the legacy Fyne UI and remove the icons_windows.go / tray_icon_*.go
split. The 6 colored tray PNGs and 6 macOS-template PNGs come from
client/ui/assets verbatim. Generation pipeline (assets/svg/) is gone.
2026-04-29 18:54:51 +02:00
Zoltán Papp
8f16a19b8f [client/ui-wails] Add windows:build:console task for log debugging
The default Windows build links the binary as a GUI subsystem app, so
stdout/stderr is detached from the launching terminal — invisible logrus
output makes tray and event-stream bugs hard to chase. Add a sibling task
that links as console subsystem and writes a separately-named binary so
the production output is preserved.

Usage:
  CGO_ENABLED=1 task windows:build:console
  bin\netbird-ui-console.exe   # logs print to the launching cmd/PowerShell
2026-04-29 16:21:45 +02:00
Zoltán Papp
504dceedf3 [client] Add Wails3 + React desktop UI scaffold
Stage 1 of the client/ui (Fyne) replacement. Adds a new client/ui-wails
module that runs on Linux/macOS/Windows from a single React + Vite +
Tailwind frontend driven by a thin gRPC services layer in Go.

- Single-module integration (no submodule): merge Wails3 into root go.mod
  with build tags !android !ios !freebsd !js so cross-compiles on those
  targets exclude the package automatically.
- Seven gRPC-bound services: Connection, Settings, Networks, Profiles,
  Debug, Update, Peers. Peers bridges Status polling and SubscribeEvents
  to the Wails event bus (netbird:status, netbird:event).
- Tray + window shell mirrors the Fyne menu 1:1 with hide-on-close,
  SIGUSR1 / Windows named-event for external "show window" triggers.
- React pages cover functional parity for Status, Settings (3 tabs),
  Networks (3 tabs), Profiles, Debug, Update, QuickActions, LoginUrl.
- SVG-sourced tray icons (12 source SVGs incl. macOS template variants)
  rasterized to PNG via task common:generate:tray:icons.
- Linux launcher sets WEBKIT_DISABLE_DMABUF_RENDERER=1 in the .desktop
  Exec= line and in task linux:run so the app renders correctly under
  RDP, VirtualBox, KVM, and bare WMs (Fluxbox/dwm) without DRM access.
2026-04-29 11:10:23 +02:00
487 changed files with 30923 additions and 9542 deletions

View File

@@ -43,5 +43,13 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined) # Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
# which fails to compile until the frontend has been built. The Wails UI
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
# before goreleaser.
# `go list -e` lets the listing succeed even though the embed fails to
# resolve; the grep then drops the broken package by path. Without -e,
# go list aborts with empty stdout and `go test` falls back to the repo
# root, which has no Go files.
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list -e ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui)

View File

@@ -154,7 +154,15 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined) # Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
# which fails to compile until the frontend has been built. The Wails UI
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
# before goreleaser.
# `go list -e` lets the listing succeed even though the embed fails to
# resolve; the grep then drops the broken package by path. Without -e,
# go list aborts with empty stdout and `go test` falls back to the repo
# root, which has no Go files.
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list -e ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui)
test_client_on_docker: test_client_on_docker:
name: "Client (Docker) / Unit" name: "Client (Docker) / Unit"
@@ -214,7 +222,7 @@ jobs:
sh -c ' \ sh -c ' \
apk update; apk add --no-cache \ apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server) go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -e -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
' '
test_relay: test_relay:

View File

@@ -64,8 +64,15 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- name: Generate test script - name: Generate test script
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
# which fails to compile until the frontend has been built. The Wails UI
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
# before goreleaser.
# `go list -e` lets the listing succeed even though the embed fails to
# resolve; the Where-Object pipeline then drops the broken package by
# path. Without -e, go list aborts with empty stdout.
run: | run: |
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' } $packages = go list -e ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' } | Where-Object { $_ -notmatch '/client/ui' }
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe" $goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1" $cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd

View File

@@ -19,8 +19,8 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
skip: go.mod,go.sum,**/proxy/web/** skip: go.mod,go.sum,**/proxy/web/**,**/pnpm-lock.yaml,**/package-lock.json
golangci: golangci:
strategy: strategy:
fail-fast: false fail-fast: false
@@ -51,6 +51,15 @@ jobs:
- name: Install dependencies - name: Install dependencies
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Stub Wails frontend bundle
# client/ui/main.go has //go:embed all:frontend/dist. The
# directory is produced by `pnpm run build` and is gitignored, so
# lint-only runs (no frontend toolchain) need a placeholder file
# for the embed pattern to match.
shell: bash
run: |
mkdir -p client/ui/frontend/dist
touch client/ui/frontend/dist/.embed-placeholder
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0 uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
with: with:

View File

@@ -186,9 +186,9 @@ jobs:
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso run: goversioninfo -icon client/ui/build/windows/icon.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Generate windows syso arm64 - name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso run: goversioninfo -arm -64 -icon client/ui/build/windows/icon.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
id: goreleaser id: goreleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
@@ -349,8 +349,18 @@ jobs:
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
- name: Set up pnpm
uses: pnpm/action-setup@v3
with:
version: 9
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 run: sudo apt update && sudo apt install -y -q libgtk-3-dev libwebkit2gtk-4.1-dev libsoup-3.0-dev libayatana-appindicator3-dev gcc-mingw-w64-x86-64
- name: Decode GPG signing key - name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
@@ -370,9 +380,9 @@ jobs:
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso run: goversioninfo -64 -icon client/ui/build/windows/icon.ico -manifest client/ui/build/windows/wails.exe.manifest -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Generate windows syso arm64 - name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso run: goversioninfo -arm -64 -icon client/ui/build/windows/icon.ico -manifest client/ui/build/windows/wails.exe.manifest -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
@@ -439,6 +449,14 @@ jobs:
run: go mod tidy run: go mod tidy
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
- name: Set up pnpm
uses: pnpm/action-setup@v3
with:
version: 9
- name: Run GoReleaser - name: Run GoReleaser
id: goreleaser id: goreleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4

1
.gitignore vendored
View File

@@ -33,3 +33,4 @@ infrastructure_files/setup-*.env
vendor/ vendor/
/netbird /netbird
client/netbird-electron/ client/netbird-electron/
management/server/types/testdata/

View File

@@ -114,6 +114,16 @@ linters:
- linters: - linters:
- staticcheck - staticcheck
text: "QF1012" text: "QF1012"
# client/ui/main.go uses //go:embed all:frontend/dist; the
# directory is populated by `pnpm build` in the release pipeline
# and missing at lint time, so the embed parses to "no matching
# files found" — surfaced by golangci-lint's typecheck pre-pass.
# Suppress just that one diagnostic; the rest of the package
# (services/, tray.go, grpc.go, ...) still gets linted normally.
- linters:
- typecheck
path: client/ui/main\.go
text: "pattern all:frontend/dist"
paths: paths:
- third_party$ - third_party$
- builtin$ - builtin$

View File

@@ -1,6 +1,11 @@
version: 2 version: 2
project_name: netbird-ui project_name: netbird-ui
before:
hooks:
- sh -c 'cd client/ui/frontend && pnpm install --frozen-lockfile && pnpm build'
builds: builds:
- id: netbird-ui - id: netbird-ui
dir: client/ui dir: client/ui
@@ -70,12 +75,15 @@ nfpms:
scripts: scripts:
postinstall: "release_files/ui-post-install.sh" postinstall: "release_files/ui-post-install.sh"
contents: contents:
- src: client/ui/build/netbird.desktop - src: client/ui/build/linux/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/assets/netbird.png - src: client/ui/build/appicon.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird
- libgtk-3-0
- libwebkit2gtk-4.1-0
- libayatana-appindicator3-1
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client UI. description: Netbird client UI.
@@ -89,12 +97,15 @@ nfpms:
scripts: scripts:
postinstall: "release_files/ui-post-install.sh" postinstall: "release_files/ui-post-install.sh"
contents: contents:
- src: client/ui/build/netbird.desktop - src: client/ui/build/linux/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/assets/netbird.png - src: client/ui/build/appicon.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird
- gtk3
- webkit2gtk4.1
- libayatana-appindicator-gtk3
rpm: rpm:
signature: signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}' key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'

View File

@@ -1,6 +1,11 @@
version: 2 version: 2
project_name: netbird-ui project_name: netbird-ui
before:
hooks:
- sh -c 'cd client/ui/frontend && pnpm install --frozen-lockfile && pnpm build'
builds: builds:
- id: netbird-ui-darwin - id: netbird-ui-darwin
dir: client/ui dir: client/ui
@@ -20,8 +25,6 @@ builds:
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
universal_binaries: universal_binaries:
- id: netbird-ui-darwin - id: netbird-ui-darwin

View File

@@ -8,7 +8,7 @@ There are many ways that you can contribute:
- Sharing use cases in slack or Reddit - Sharing use cases in slack or Reddit
- Bug fix or feature enhancement - Bug fix or feature enhancement
If you haven't already, join our slack workspace [here](https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A), we would love to discuss topics that need community contribution and enhancements to existing features. If you haven't already, join our slack workspace [here](https://docs.netbird.io/slack-url), we would love to discuss topics that need community contribution and enhancements to existing features.
## Contents ## Contents

View File

@@ -301,10 +301,11 @@ func (c *Client) PeersList() *PeerInfoArray {
peerInfos := make([]PeerInfo, len(fullStatus.Peers)) peerInfos := make([]PeerInfo, len(fullStatus.Peers))
for n, p := range fullStatus.Peers { for n, p := range fullStatus.Peers {
pi := PeerInfo{ pi := PeerInfo{
p.IP, IP: p.IP,
p.FQDN, IPv6: p.IPv6,
int(p.ConnStatus), FQDN: p.FQDN,
PeerRoutes{routes: maps.Keys(p.GetRoutes())}, ConnStatus: int(p.ConnStatus),
Routes: PeerRoutes{routes: maps.Keys(p.GetRoutes())},
} }
peerInfos[n] = pi peerInfos[n] = pi
} }
@@ -336,43 +337,84 @@ func (c *Client) Networks() *NetworkArray {
return nil return nil
} }
routesMap := routeManager.GetClientRoutesWithNetID()
v6Merged := route.V6ExitMergeSet(routesMap)
resolvedDomains := c.recorder.GetResolvedDomainsStates()
networkArray := &NetworkArray{ networkArray := &NetworkArray{
items: make([]Network, 0), items: make([]Network, 0),
} }
resolvedDomains := c.recorder.GetResolvedDomainsStates() for id, routes := range routesMap {
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 { if len(routes) == 0 {
continue continue
} }
if _, skip := v6Merged[id]; skip {
r := routes[0]
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue continue
} }
network := Network{
Name: string(id), network := c.buildNetwork(id, routes, routeSelector.IsSelected(id), resolvedDomains, v6Merged)
Network: netStr, if network == nil {
Peer: routePeer.FQDN, continue
Status: routePeer.ConnStatus.String(),
IsSelected: routeSelector.IsSelected(id),
Domains: domains,
} }
networkArray.Add(network) networkArray.Add(*network)
} }
return networkArray return networkArray
} }
func (c *Client) buildNetwork(id route.NetID, routes []*route.Route, selected bool, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo, v6Merged map[route.NetID]struct{}) *Network {
r := routes[0]
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
routePeer, err := c.findBestRoutePeer(routes)
if err != nil {
log.Errorf("could not get peer info for route %s: %v", id, err)
return nil
}
network := &Network{
Name: string(id),
Network: netStr,
Peer: routePeer.FQDN,
Status: routePeer.ConnStatus.String(),
IsSelected: selected,
Domains: c.getNetworkDomainsFromRoute(r, resolvedDomains),
}
if route.IsV4DefaultRoute(r.Network) && route.HasV6ExitPair(id, v6Merged) {
network.Network = "0.0.0.0/0, ::/0"
}
return network
}
// findBestRoutePeer returns the peer actively routing traffic for the given
// HA route group. Falls back to the first connected peer, then the first peer.
func (c *Client) findBestRoutePeer(routes []*route.Route) (peer.State, error) {
netStr := routes[0].Network.String()
fullStatus := c.recorder.GetFullStatus()
for _, p := range fullStatus.Peers {
if _, ok := p.GetRoutes()[netStr]; ok {
return p, nil
}
}
for _, r := range routes {
p, err := c.recorder.GetPeer(r.Peer)
if err != nil {
continue
}
if p.ConnStatus == peer.StatusConnected {
return p, nil
}
}
return c.recorder.GetPeer(routes[0].Peer)
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones // OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error { func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns() dnsServer, err := dns.GetServerDns()

View File

@@ -14,6 +14,7 @@ const (
// PeerInfo describe information about the peers. It designed for the UI usage // PeerInfo describe information about the peers. It designed for the UI usage
type PeerInfo struct { type PeerInfo struct {
IP string IP string
IPv6 string
FQDN string FQDN string
ConnStatus int ConnStatus int
Routes PeerRoutes Routes PeerRoutes

View File

@@ -307,6 +307,24 @@ func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block p.configInput.BlockInbound = &block
} }
// GetDisableIPv6 reads disable IPv6 setting from config file
func (p *Preferences) GetDisableIPv6() (bool, error) {
if p.configInput.DisableIPv6 != nil {
return *p.configInput.DisableIPv6, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableIPv6, err
}
// SetDisableIPv6 stores the given value and waits for commit
func (p *Preferences) SetDisableIPv6(disable bool) {
p.configInput.DisableIPv6 = &disable
}
// Commit writes out the changes to the config file // Commit writes out the changes to the config file
func (p *Preferences) Commit() error { func (p *Preferences) Commit() error {
_, err := profilemanager.UpdateOrCreateConfig(p.configInput) _, err := profilemanager.UpdateOrCreateConfig(p.configInput)

View File

@@ -18,9 +18,12 @@ func executeRouteToggle(id string, manager routemanager.Manager,
netID := route.NetID(id) netID := route.NetID(id)
routes := []route.NetID{netID} routes := []route.NetID{netID}
log.Debugf("%s with id: %s", operationName, id) routesMap := manager.GetClientRoutesWithNetID()
routes = route.ExpandV6ExitPairs(routes, routesMap)
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil { log.Debugf("%s with ids: %v", operationName, routes)
if err := routeOperation(routes, maps.Keys(routesMap)); err != nil {
log.Debugf("error when %s: %s", operationName, err) log.Debugf("error when %s: %s", operationName, err)
return fmt.Errorf("error %s: %w", operationName, err) return fmt.Errorf("error %s: %w", operationName, err)
} }

View File

@@ -9,6 +9,7 @@ import (
"net/url" "net/url"
"regexp" "regexp"
"slices" "slices"
"strconv"
"strings" "strings"
) )
@@ -26,8 +27,9 @@ type Anonymizer struct {
} }
func DefaultAddresses() (netip.Addr, netip.Addr) { func DefaultAddresses() (netip.Addr, netip.Addr) {
// 198.51.100.0, 100:: // 198.51.100.0 (RFC 5737 TEST-NET-2), 2001:db8:ffff:: (RFC 3849 documentation, last /48)
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01}) // The old start 100:: (discard, RFC 6666) is now used for fake IPs on Android.
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.MustParseAddr("2001:db8:ffff::")
} }
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer { func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
@@ -48,7 +50,7 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
ip.IsLinkLocalUnicast() || ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() || ip.IsLinkLocalMulticast() ||
ip.IsInterfaceLocalMulticast() || ip.IsInterfaceLocalMulticast() ||
ip.IsPrivate() || (ip.Is4() && ip.IsPrivate()) ||
ip.IsUnspecified() || ip.IsUnspecified() ||
ip.IsMulticast() || ip.IsMulticast() ||
isWellKnown(ip) || isWellKnown(ip) ||
@@ -96,6 +98,11 @@ func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
} }
func (a *Anonymizer) AnonymizeIPString(ip string) string { func (a *Anonymizer) AnonymizeIPString(ip string) string {
// Handle CIDR notation (e.g. "2001:db8::/32")
if prefix, err := netip.ParsePrefix(ip); err == nil {
return a.AnonymizeIP(prefix.Addr()).String() + "/" + strconv.Itoa(prefix.Bits())
}
addr, err := netip.ParseAddr(ip) addr, err := netip.ParseAddr(ip)
if err != nil { if err != nil {
return ip return ip
@@ -150,7 +157,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
if u.Opaque != "" { if u.Opaque != "" {
host, port, err := net.SplitHostPort(u.Opaque) host, port, err := net.SplitHostPort(u.Opaque)
if err == nil { if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port) anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
} else { } else {
anonymizedHost = a.AnonymizeDomain(u.Opaque) anonymizedHost = a.AnonymizeDomain(u.Opaque)
} }
@@ -158,7 +165,7 @@ func (a *Anonymizer) AnonymizeURI(uri string) string {
} else if u.Host != "" { } else if u.Host != "" {
host, port, err := net.SplitHostPort(u.Host) host, port, err := net.SplitHostPort(u.Host)
if err == nil { if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port) anonymizedHost = net.JoinHostPort(a.AnonymizeDomain(host), port)
} else { } else {
anonymizedHost = a.AnonymizeDomain(u.Host) anonymizedHost = a.AnonymizeDomain(u.Host)
} }

View File

@@ -13,7 +13,7 @@ import (
func TestAnonymizeIP(t *testing.T) { func TestAnonymizeIP(t *testing.T) {
startIPv4 := netip.MustParseAddr("198.51.100.0") startIPv4 := netip.MustParseAddr("198.51.100.0")
startIPv6 := netip.MustParseAddr("100::") startIPv6 := netip.MustParseAddr("2001:db8:ffff::")
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6) anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
tests := []struct { tests := []struct {
@@ -26,9 +26,9 @@ func TestAnonymizeIP(t *testing.T) {
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"}, {"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"}, {"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
{"Private IPv4", "192.168.1.1", "192.168.1.1"}, {"Private IPv4", "192.168.1.1", "192.168.1.1"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"}, {"First Public IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
{"Second Public IPv6", "a::b", "100::1"}, {"Second Public IPv6", "a::b", "2001:db8:ffff::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"}, {"Repeated IPv6", "2607:f8b0:4005:805::200e", "2001:db8:ffff::"},
{"Private IPv6", "fe80::1", "fe80::1"}, {"Private IPv6", "fe80::1", "fe80::1"},
{"In Range IPv4", "198.51.100.2", "198.51.100.2"}, {"In Range IPv4", "198.51.100.2", "198.51.100.2"},
} }
@@ -274,17 +274,27 @@ func TestAnonymizeString_IPAddresses(t *testing.T) {
{ {
name: "IPv6 Address", name: "IPv6 Address",
input: "Access attempted from 2001:db8::ff00:42", input: "Access attempted from 2001:db8::ff00:42",
expect: "Access attempted from 100::", expect: "Access attempted from 2001:db8:ffff::",
}, },
{ {
name: "IPv6 Address with Port", name: "IPv6 Address with Port",
input: "Access attempted from [2001:db8::ff00:42]:8080", input: "Access attempted from [2001:db8::ff00:42]:8080",
expect: "Access attempted from [100::]:8080", expect: "Access attempted from [2001:db8:ffff::]:8080",
}, },
{ {
name: "Both IPv4 and IPv6", name: "Both IPv4 and IPv6",
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43", input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 100::1", expect: "IPv4: 198.51.100.1 and IPv6: 2001:db8:ffff::1",
},
{
name: "STUN URI with IPv6",
input: "Connecting to stun:[2001:db8::ff00:42]:3478",
expect: "Connecting to stun:[2001:db8:ffff::]:3478",
},
{
name: "HTTPS URI with IPv6",
input: "Visit https://[2001:db8::ff00:42]:443/path",
expect: "Visit https://[2001:db8:ffff::]:443/path",
}, },
} }

View File

@@ -523,7 +523,7 @@ func parseHostnameAndCommand(args []string) error {
} }
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error { func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port) target := net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(port))
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{ c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile, KnownHostsFile: knownHostsFile,
IdentityFile: identityFile, IdentityFile: identityFile,
@@ -787,10 +787,10 @@ func isUnixSocket(path string) bool {
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
} }
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces. // normalizeLocalHost converts "*" to "" for binding to all interfaces (dual-stack).
func normalizeLocalHost(host string) string { func normalizeLocalHost(host string) string {
if host == "*" { if host == "*" {
return "0.0.0.0" return ""
} }
return host return host
} }

View File

@@ -527,10 +527,10 @@ func TestParsePortForward(t *testing.T) {
{ {
name: "wildcard bind all interfaces", name: "wildcard bind all interfaces",
spec: "*:8080:localhost:80", spec: "*:8080:localhost:80",
expectedLocal: "0.0.0.0:8080", expectedLocal: ":8080",
expectedRemote: "localhost:80", expectedRemote: "localhost:80",
expectError: false, expectError: false,
description: "Wildcard * should bind to all interfaces (0.0.0.0)", description: "Wildcard * should bind to all interfaces (dual-stack)",
}, },
{ {
name: "wildcard for port only", name: "wildcard for port only",

View File

@@ -20,6 +20,7 @@ import (
var ( var (
detailFlag bool detailFlag bool
ipv4Flag bool ipv4Flag bool
ipv6Flag bool
jsonFlag bool jsonFlag bool
yamlFlag bool yamlFlag bool
ipsFilter []string ipsFilter []string
@@ -45,8 +46,9 @@ func init() {
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format") statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format") statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33") statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4") statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P") statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
@@ -101,6 +103,14 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
if ipv6Flag {
ipv6 := resp.GetFullStatus().GetLocalPeerState().GetIpv6()
if ipv6 != "" {
cmd.Print(parseInterfaceIP(ipv6))
}
return nil
}
pm := profilemanager.NewProfileManager() pm := profilemanager.NewProfileManager()
var profName string var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil { if activeProf, err := pm.GetActiveProfile(); err == nil {

View File

@@ -8,6 +8,7 @@ const (
disableFirewallFlag = "disable-firewall" disableFirewallFlag = "disable-firewall"
blockLANAccessFlag = "block-lan-access" blockLANAccessFlag = "block-lan-access"
blockInboundFlag = "block-inbound" blockInboundFlag = "block-inbound"
disableIPv6Flag = "disable-ipv6"
) )
var ( var (
@@ -17,6 +18,7 @@ var (
disableFirewall bool disableFirewall bool
blockLANAccess bool blockLANAccess bool
blockInbound bool blockInbound bool
disableIPv6 bool
) )
func init() { func init() {
@@ -39,4 +41,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false, upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+ "Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.") "This overrides any policies received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableIPv6, disableIPv6Flag, false,
"Disable IPv6 overlay. If enabled, the client won't request or use an IPv6 overlay address.")
} }

View File

@@ -435,6 +435,10 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.BlockInbound = &blockInbound req.BlockInbound = &blockInbound
} }
if cmd.Flag(disableIPv6Flag).Changed {
req.DisableIpv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed { if cmd.Flag(enableLazyConnectionFlag).Changed {
req.LazyConnectionEnabled = &lazyConnEnabled req.LazyConnectionEnabled = &lazyConnEnabled
} }
@@ -552,6 +556,10 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.BlockInbound = &blockInbound ic.BlockInbound = &blockInbound
} }
if cmd.Flag(disableIPv6Flag).Changed {
ic.DisableIPv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed { if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled ic.LazyConnectionEnabled = &lazyConnEnabled
} }
@@ -666,6 +674,10 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.BlockInbound = &blockInbound loginRequest.BlockInbound = &blockInbound
} }
if cmd.Flag(disableIPv6Flag).Changed {
loginRequest.DisableIpv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed { if cmd.Flag(enableLazyConnectionFlag).Changed {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled loginRequest.LazyConnectionEnabled = &lazyConnEnabled
} }

View File

@@ -80,6 +80,8 @@ type Options struct {
StatePath string StatePath string
// DisableClientRoutes disables the client routes // DisableClientRoutes disables the client routes
DisableClientRoutes bool DisableClientRoutes bool
// DisableIPv6 disables IPv6 overlay addressing
DisableIPv6 bool
// BlockInbound blocks all inbound connections from peers // BlockInbound blocks all inbound connections from peers
BlockInbound bool BlockInbound bool
// WireguardPort is the port for the tunnel interface. Use 0 for a random port. // WireguardPort is the port for the tunnel interface. Use 0 for a random port.
@@ -171,6 +173,7 @@ func New(opts Options) (*Client, error) {
PreSharedKey: &opts.PreSharedKey, PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t, DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes, DisableClientRoutes: &opts.DisableClientRoutes,
DisableIPv6: &opts.DisableIPv6,
BlockInbound: &opts.BlockInbound, BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort, WireguardPort: opts.WireguardPort,
MTU: opts.MTU, MTU: opts.MTU,

View File

@@ -40,6 +40,7 @@ type aclManager struct {
entries aclEntries entries aclEntries
optionalEntries map[string][]entry optionalEntries map[string][]entry
ipsetStore *ipsetStore ipsetStore *ipsetStore
v6 bool
stateManager *statemanager.Manager stateManager *statemanager.Manager
} }
@@ -51,6 +52,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl
entries: make(map[string][][]string), entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry), optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
}, nil }, nil
} }
@@ -85,7 +87,11 @@ func (m *aclManager) AddPeerFiltering(
chain := chainNameInputRules chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action) ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName) if m.v6 && ipsetName != "" {
ipsetName += "-v6"
}
proto := protoForFamily(protocol, m.v6)
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs) mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs, mangleSpecs = append(mangleSpecs,
@@ -109,6 +115,7 @@ func (m *aclManager) AddPeerFiltering(
ip: ip.String(), ip: ip.String(),
chain: chain, chain: chain,
specs: specs, specs: specs,
v6: m.v6,
}}, nil }}, nil
} }
@@ -161,6 +168,7 @@ func (m *aclManager) AddPeerFiltering(
ipsetName: ipsetName, ipsetName: ipsetName,
ip: ip.String(), ip: ip.String(),
chain: chain, chain: chain,
v6: m.v6,
} }
m.updateState() m.updateState()
@@ -413,8 +421,13 @@ func (m *aclManager) updateState() {
currentState.Lock() currentState.Lock()
defer currentState.Unlock() defer currentState.Unlock()
currentState.ACLEntries = m.entries if m.v6 {
currentState.ACLIPsetStore = m.ipsetStore currentState.ACLEntries6 = m.entries
currentState.ACLIPsetStore6 = m.ipsetStore
} else {
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
}
if err := m.stateManager.UpdateState(currentState); err != nil { if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
@@ -422,13 +435,22 @@ func (m *aclManager) updateState() {
} }
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
// don't use IP matching if IP is 0.0.0.0 // don't use IP matching if IP is 0.0.0.0
matchByIP := !ip.IsUnspecified() matchByIP := !ip.IsUnspecified()
if matchByIP { if matchByIP {
if ipsetName != "" { if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src") specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
} else { } else {
specs = append(specs, "-s", ip.String()) specs = append(specs, "-s", ip.String())
} }
@@ -474,6 +496,9 @@ func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{ opts := ipset.CreateOptions{
Replace: true, Replace: true,
} }
if m.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err) return fmt.Errorf("create ipset %s: %w", name, err)

View File

@@ -18,6 +18,10 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
type resetter interface {
Reset() error
}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
@@ -28,6 +32,11 @@ type Manager struct {
aclMgr *aclManager aclMgr *aclManager
router *router router *router
rawSupported bool rawSupported bool
// IPv6 counterparts, nil when no v6 overlay
ipv6Client *iptables.IPTables
aclMgr6 *aclManager
router6 *router
} }
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
@@ -58,9 +67,43 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
return nil, fmt.Errorf("create acl manager: %w", err) return nil, fmt.Errorf("create acl manager: %w", err)
} }
if wgIface.Address().HasIPv6() {
if err := m.createIPv6Components(wgIface, mtu); err != nil {
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
}
}
return m, nil return m, nil
} }
func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
ip6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
return fmt.Errorf("init ip6tables: %w", err)
}
m.ipv6Client = ip6Client
m.router6, err = newRouter(ip6Client, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
func (m *Manager) hasIPv6() bool {
return m.ipv6Client != nil
}
func (m *Manager) Init(stateManager *statemanager.Manager) error { func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{ state := &ShutdownState{
InterfaceState: &InterfaceState{ InterfaceState: &InterfaceState{
@@ -74,13 +117,8 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
} }
if err := m.router.init(stateManager); err != nil { if err := m.initChains(stateManager); err != nil {
return fmt.Errorf("router init: %w", err) return err
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
} }
if err := m.initNoTrackChain(); err != nil { if err := m.initNoTrackChain(); err != nil {
@@ -103,6 +141,41 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return nil return nil
} }
// initChains initializes router and ACL chains for both address families,
// rolling back on failure.
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
type initStep struct {
name string
init func(*statemanager.Manager) error
mgr resetter
}
steps := []initStep{
{"router", m.router.init, m.router},
{"acl manager", m.aclMgr.init, m.aclMgr},
}
if m.hasIPv6() {
steps = append(steps,
initStep{"v6 router", m.router6.init, m.router6},
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
)
}
var initialized []initStep
for _, s := range steps {
if err := s.init(stateManager); err != nil {
for i := len(initialized) - 1; i >= 0; i-- {
if rerr := initialized[i].mgr.Reset(); rerr != nil {
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
}
}
return fmt.Errorf("%s init: %w", s.name, err)
}
initialized = append(initialized, s)
}
return nil
}
// AddPeerFiltering adds a rule to the firewall // AddPeerFiltering adds a rule to the firewall
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
@@ -118,7 +191,13 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) if ip.To4() != nil {
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
@@ -132,25 +211,48 @@ func (m *Manager) AddRouteFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { if isIPv6RouteRule(sources, destination) {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error { func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6IptRule(rule) {
return m.aclMgr6.DeletePeerRule(rule)
}
return m.aclMgr.DeletePeerRule(rule) return m.aclMgr.DeletePeerRule(rule)
} }
func isIPv6IptRule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.v6
}
// DeleteRouteRule deletes a routing rule.
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
return m.router6.DeleteRouteRule(rule)
}
return m.router.DeleteRouteRule(rule) return m.router.DeleteRouteRule(rule)
} }
@@ -166,18 +268,65 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddNatRule(pair) if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
return err
}
// Dynamic routes need NAT in both tables since resolved IPs can be
// either v4 or v6. This covers both DomainSet (modern) and the legacy
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
return nil
} }
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair) if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.router.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
func (m *Manager) SetLegacyManagement(isLegacy bool) error { func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy) if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
} }
// Reset firewall to the default state // Reset firewall to the default state
@@ -191,6 +340,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err)) merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
} }
if m.hasIPv6() {
if err := m.aclMgr6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
}
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
}
}
if err := m.aclMgr.Reset(); err != nil { if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
} }
@@ -218,24 +376,21 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
// This is called when USPFilter wraps the native firewall, adding blanket accept // This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter. // rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error { func (m *Manager) AllowNetbird() error {
_, err := m.AddPeerFiltering( var merr *multierror.Error
nil, if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
net.IP{0, 0, 0, 0}, merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
firewall.ProtocolALL, }
nil, if m.hasIPv6() {
nil, if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
firewall.ActionAccept, merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
"", }
)
if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err)
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err) log.Warnf("failed to trust interface in firewalld: %v", err)
} }
return nil return nberrors.FormatErrorOrNil(merr)
} }
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
@@ -265,6 +420,12 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if rule.TranslatedAddress.Is6() {
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule) return m.router.AddDNATRule(rule)
} }
@@ -273,6 +434,9 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
return m.router6.DeleteDNATRule(rule)
}
return m.router.DeleteDNATRule(rule) return m.router.DeleteDNATRule(rule)
} }
@@ -281,39 +445,82 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes) var v4Prefixes, v6Prefixes []netip.Prefix
for _, p := range prefixes {
if p.Addr().Is6() {
v6Prefixes = append(v6Prefixes, p)
} else {
v4Prefixes = append(v4Prefixes, p)
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
return nil
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
} }
// RemoveInboundDNAT removes an inbound DNAT rule. // RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
} }
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
} }
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. // RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
} }
const ( const (

View File

@@ -54,8 +54,10 @@ const (
snatSuffix = "_snat" snatSuffix = "_snat"
fwdSuffix = "_fwd" fwdSuffix = "_fwd"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipTCPHeaderMinSize = 40 ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
) )
type ruleInfo struct { type ruleInfo struct {
@@ -86,6 +88,7 @@ type router struct {
wgIface iFaceMapper wgIface iFaceMapper
legacyManagement bool legacyManagement bool
mtu uint16 mtu uint16
v6 bool
stateManager *statemanager.Manager stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
@@ -97,6 +100,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint1
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface, wgIface: wgIface,
mtu: mtu, mtu: mtu,
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
ipFwdState: ipfwdstate.NewIPForwardingState(), ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
@@ -186,6 +190,11 @@ func (r *router) AddRouteFiltering(
return ruleKey, nil return ruleKey, nil
} }
func (r *router) hasRule(id string) bool {
_, ok := r.rules[id]
return ok
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID() ruleKey := rule.ID()
@@ -392,9 +401,13 @@ func (r *router) cleanUpDefaultForwardRules() error {
// Remove jump rules from built-in chains before deleting custom chains, // Remove jump rules from built-in chains before deleting custom chains,
// otherwise the chain deletion fails with "device or resource busy". // otherwise the chain deletion fails with "device or resource busy".
jumpRule := []string{"-j", chainNATOutput} if ok, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput); err != nil {
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil { return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
log.Debugf("clean OUTPUT jump rule: %v", err) } else if ok {
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
log.Debugf("clean OUTPUT jump rule: %v", err)
}
} }
for _, chainInfo := range []struct { for _, chainInfo := range []struct {
@@ -434,6 +447,12 @@ func (r *router) createContainers() error {
{chainRTRDR, tableNat}, {chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle}, {chainRTMSSCLAMP, tableMangle},
} { } {
// Fallback: clear chains that survived an unclean shutdown.
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} }
@@ -540,9 +559,12 @@ func (r *router) addPostroutingRules() error {
} }
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error { func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize overhead := uint16(ipv4TCPHeaderSize)
if r.v6 {
overhead = ipv6TCPHeaderSize
}
mss := r.mtu - overhead
// Add jump rule from FORWARD chain in mangle table to our custom chain // Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{ jumpRule := []string{
@@ -727,8 +749,13 @@ func (r *router) updateState() {
currentState.Lock() currentState.Lock()
defer currentState.Unlock() defer currentState.Unlock()
currentState.RouteRules = r.rules if r.v6 {
currentState.RouteIPsetCounter = r.ipsetCounter currentState.RouteRules6 = r.rules
currentState.RouteIPsetCounter6 = r.ipsetCounter
} else {
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
}
if err := r.stateManager.UpdateState(currentState); err != nil { if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
@@ -856,7 +883,7 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
} }
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists { if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil { if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err)) merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
} }
delete(r.rules, ruleKey+fwdSuffix) delete(r.rules, ruleKey+fwdSuffix)
@@ -883,7 +910,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
rule = append(rule, destExp...) rule = append(rule, destExp...)
if params.Proto != firewall.ProtocolALL { if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto))) rule = append(rule, "-p", strings.ToLower(protoForFamily(params.Proto, r.v6)))
rule = append(rule, applyPort("--sport", params.SPort)...) rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...) rule = append(rule, applyPort("--dport", params.DPort)...)
} }
@@ -900,11 +927,12 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
} }
if network.IsSet() { if network.IsSet() {
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil { name := r.ipsetName(network.Set.HashedName())
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err) return nil, fmt.Errorf("create or get ipset: %w", err)
} }
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil return []string{"-m", "set", matchSet, name, direction}, nil
} }
if network.IsPrefix() { if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil return []string{flag, network.Prefix.String()}, nil
@@ -915,27 +943,23 @@ func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []
} }
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
name := r.ipsetName(set.HashedName())
var merr *multierror.Error var merr *multierror.Error
for _, prefix := range prefixes { for _, prefix := range prefixes {
// TODO: Implement IPv6 support if err := r.addPrefixToIPSet(name, prefix); err != nil {
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err)) merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
} }
} }
if merr == nil { if merr == nil {
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes) log.Debugf("updated set %s with prefixes %v", name, prefixes)
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
if _, exists := r.rules[ruleID]; exists { if _, exists := r.rules[ruleID]; exists {
return nil return nil
@@ -943,12 +967,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
dnatRule := []string{ dnatRule := []string{
"-i", r.wgIface.Name(), "-i", r.wgIface.Name(),
"-p", strings.ToLower(string(protocol)), "-p", strings.ToLower(protoForFamily(protocol, r.v6)),
"--dport", strconv.Itoa(int(sourcePort)), "--dport", strconv.Itoa(int(originalPort)),
"-d", localAddr.String(), "-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL", "-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT", "-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)), "--to-destination", ":" + strconv.Itoa(int(translatedPort)),
} }
ruleInfo := ruleInfo{ ruleInfo := ruleInfo{
@@ -967,8 +991,8 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
} }
// RemoveInboundDNAT removes an inbound DNAT rule. // RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
if dnatRule, exists := r.rules[ruleID]; exists { if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil { if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
@@ -1013,8 +1037,8 @@ func (r *router) ensureNATOutputChain() error {
} }
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
if _, exists := r.rules[ruleID]; exists { if _, exists := r.rules[ruleID]; exists {
return nil return nil
@@ -1025,11 +1049,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
} }
dnatRule := []string{ dnatRule := []string{
"-p", strings.ToLower(string(protocol)), "-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
"--dport", strconv.Itoa(int(sourcePort)), "--dport", strconv.Itoa(int(originalPort)),
"-d", localAddr.String(), "-d", localAddr.String(),
"-j", "DNAT", "-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)), "--to-destination", ":" + strconv.Itoa(int(translatedPort)),
} }
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil { if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
@@ -1042,8 +1066,8 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
} }
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. // RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
if dnatRule, exists := r.rules[ruleID]; exists { if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil { if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
@@ -1076,10 +1100,22 @@ func applyPort(flag string, port *firewall.Port) []string {
return []string{flag, strconv.Itoa(int(port.Values[0]))} return []string{flag, strconv.Itoa(int(port.Values[0]))}
} }
// ipsetName returns the ipset name, suffixed with "-v6" for the v6 router
// to avoid collisions since ipsets are global in the kernel.
func (r *router) ipsetName(name string) string {
if r.v6 {
return name + "-v6"
}
return name
}
func (r *router) createIPSet(name string) error { func (r *router) createIPSet(name string) error {
opts := ipset.CreateOptions{ opts := ipset.CreateOptions{
Replace: true, Replace: true,
} }
if r.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil { if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err) return fmt.Errorf("create ipset %s: %w", name, err)

View File

@@ -9,6 +9,7 @@ type Rule struct {
mangleSpecs []string mangleSpecs []string
ip string ip string
chain string chain string
v6 bool
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id

View File

@@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"sync" "sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -32,6 +34,12 @@ type ShutdownState struct {
ACLEntries aclEntries `json:"acl_entries,omitempty"` ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"` ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
// IPv6 counterparts
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
} }
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
@@ -62,6 +70,28 @@ func (s *ShutdownState) Cleanup() error {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore ipt.aclMgr.ipsetStore = s.ACLIPsetStore
} }
// Clean up v6 state even if the current run has no IPv6.
// The previous run may have left ip6tables rules behind.
if !ipt.hasIPv6() {
if err := ipt.createIPv6Components(s.InterfaceState, mtu); err != nil {
log.Warnf("failed to create v6 components for cleanup: %v", err)
}
}
if ipt.hasIPv6() {
if s.RouteRules6 != nil {
ipt.router6.rules = s.RouteRules6
}
if s.RouteIPsetCounter6 != nil {
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
}
if s.ACLEntries6 != nil {
ipt.aclMgr6.entries = s.ACLEntries6
}
if s.ACLIPsetStore6 != nil {
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
}
}
if err := ipt.Close(nil); err != nil { if err := ipt.Close(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err) return fmt.Errorf("reset iptables manager: %w", err)
} }

View File

@@ -1,6 +1,7 @@
package manager package manager
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -11,6 +12,10 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// ErrIPv6NotInitialized is returned when an IPv6 address is passed to a firewall
// method but the IPv6 firewall components were not initialized.
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
const ( const (
ForwardingFormatPrefix = "netbird-fwd-" ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t" ForwardingFormat = "netbird-fwd-%s-%t"
@@ -164,18 +169,16 @@ type Manager interface {
UpdateSet(hash Set, prefixes []netip.Prefix) error UpdateSet(hash Set, prefixes []netip.Prefix) error
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error AddInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
// RemoveInboundDNAT removes inbound DNAT rule // RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only. AddOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. // RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only. RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, originalPort, translatedPort uint16) error
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic. // SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication. // This prevents conntrack from interfering with WireGuard proxy communication.

View File

@@ -1,6 +1,8 @@
package manager package manager
import ( import (
"net/netip"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -10,6 +12,10 @@ type RouterPair struct {
Destination Network Destination Network
Masquerade bool Masquerade bool
Inverse bool Inverse bool
// Dynamic indicates the route is domain-based. NAT rules for dynamic
// routes are duplicated to the v6 table so that resolved AAAA records
// are masqueraded correctly.
Dynamic bool
} }
func GetInversePair(pair RouterPair) RouterPair { func GetInversePair(pair RouterPair) RouterPair {
@@ -20,5 +26,17 @@ func GetInversePair(pair RouterPair) RouterPair {
Destination: pair.Source, Destination: pair.Source,
Masquerade: pair.Masquerade, Masquerade: pair.Masquerade,
Inverse: true, Inverse: true,
Dynamic: pair.Dynamic,
} }
} }
// ToV6NatPair creates a v6 counterpart of a v4 NAT pair with `::/0` source
// and, for prefix destinations, `::/0` destination.
func ToV6NatPair(pair RouterPair) RouterPair {
v6 := pair
v6.Source = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
if v6.Destination.IsPrefix() {
v6.Destination = Network{Prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
return v6
}

View File

@@ -33,15 +33,12 @@ const (
const flushError = "flush: %w" const flushError = "flush: %w"
var (
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
type AclManager struct { type AclManager struct {
rConn *nftables.Conn rConn *nftables.Conn
sConn *nftables.Conn sConn *nftables.Conn
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routingFwChainName string
af addrFamily
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
@@ -67,6 +64,7 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
wgIface: wgIface, wgIface: wgIface,
workTable: table, workTable: table,
routingFwChainName: routingFwChainName, routingFwChainName: routingFwChainName,
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule), rules: make(map[string]*Rule),
@@ -145,7 +143,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
} }
if _, ok := ips[r.ip.String()]; ok { if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}}) err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
if err != nil { if err != nil {
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err) log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
} }
@@ -254,11 +252,11 @@ func (m *AclManager) addIOFiltering(
expressions = append(expressions, &expr.Payload{ expressions = append(expressions, &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9), Offset: m.af.protoOffset,
Len: uint32(1), Len: uint32(1),
}) })
protoData, err := protoToInt(proto) protoData, err := m.af.protoNum(proto)
if err != nil { if err != nil {
return nil, fmt.Errorf("convert protocol to number: %v", err) return nil, fmt.Errorf("convert protocol to number: %v", err)
} }
@@ -270,19 +268,16 @@ func (m *AclManager) addIOFiltering(
}) })
} }
rawIP := ip.To4() rawIP := ipToBytes(ip, m.af)
// check if rawIP contains zeroed IPv4 0.0.0.0 value // check if rawIP contains zeroed IPv4 0.0.0.0 value
// in that case not add IP match expression into the rule definition // in that case not add IP match expression into the rule definition
if !bytes.HasPrefix(anyIP, rawIP) { if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
// source address position
addrOffset := uint32(12)
expressions = append(expressions, expressions = append(expressions,
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: addrOffset, Offset: m.af.srcAddrOffset,
Len: 4, Len: m.af.addrLen,
}, },
) )
// add individual IP for match if no ipset defined // add individual IP for match if no ipset defined
@@ -587,7 +582,7 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) { func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName) ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
rawIP := ip.To4() rawIP := ipToBytes(ip, m.af)
if err != nil { if err != nil {
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil { if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err) return nil, fmt.Errorf("get set name: %v", err)
@@ -619,7 +614,7 @@ func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Se
Name: name, Name: name,
Table: table, Table: table,
Dynamic: true, Dynamic: true,
KeyType: nftables.TypeIPAddr, KeyType: m.af.setKeyType,
} }
if err := m.rConn.AddSet(ipset, nil); err != nil { if err := m.rConn.AddSet(ipset, nil); err != nil {
@@ -707,15 +702,12 @@ func ifname(n string) []byte {
return b return b
} }
func protoToInt(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return unix.IPPROTO_ICMP, nil
}
return 0, fmt.Errorf("unsupported protocol: %s", protocol) // ipToBytes converts net.IP to the correct byte length for the address family.
func ipToBytes(ip net.IP, af addrFamily) []byte {
if af.addrLen == 4 {
return ip.To4()
}
return ip.To16()
} }

View File

@@ -0,0 +1,81 @@
package nftables
import (
"fmt"
"net"
"github.com/google/nftables"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var (
// afIPv4 defines IPv4 header layout and nftables types.
afIPv4 = addrFamily{
protoOffset: 9,
srcAddrOffset: 12,
dstAddrOffset: 16,
addrLen: net.IPv4len,
totalBits: 8 * net.IPv4len,
setKeyType: nftables.TypeIPAddr,
tableFamily: nftables.TableFamilyIPv4,
icmpProto: unix.IPPROTO_ICMP,
}
// afIPv6 defines IPv6 header layout and nftables types.
afIPv6 = addrFamily{
protoOffset: 6,
srcAddrOffset: 8,
dstAddrOffset: 24,
addrLen: net.IPv6len,
totalBits: 8 * net.IPv6len,
setKeyType: nftables.TypeIP6Addr,
tableFamily: nftables.TableFamilyIPv6,
icmpProto: unix.IPPROTO_ICMPV6,
}
)
// addrFamily holds protocol-specific constants for nftables expression building.
type addrFamily struct {
// protoOffset is the IP header offset for the protocol/next-header field (9 for v4, 6 for v6)
protoOffset uint32
// srcAddrOffset is the IP header offset for the source address (12 for v4, 8 for v6)
srcAddrOffset uint32
// dstAddrOffset is the IP header offset for the destination address (16 for v4, 24 for v6)
dstAddrOffset uint32
// addrLen is the byte length of addresses (4 for v4, 16 for v6)
addrLen uint32
// totalBits is the address size in bits (32 for v4, 128 for v6)
totalBits int
// setKeyType is the nftables set data type for addresses
setKeyType nftables.SetDatatype
// tableFamily is the nftables table family
tableFamily nftables.TableFamily
// icmpProto is the ICMP protocol number for this family (1 for v4, 58 for v6)
icmpProto uint8
}
// familyForAddr returns the address family for the given IP.
func familyForAddr(is4 bool) addrFamily {
if is4 {
return afIPv4
}
return afIPv6
}
// protoNum converts a firewall protocol to the IP protocol number,
// using the correct ICMP variant for the address family.
func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return af.icmpProto, nil
case firewall.ProtocolALL:
return 0, nil
default:
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}
}

View File

@@ -0,0 +1,76 @@
//go:build linux
package nftables
import (
"os"
"sync/atomic"
"testing"
"time"
"github.com/google/nftables"
"github.com/stretchr/testify/require"
)
// TestExternalChainMonitorRootIntegration verifies that adding a new chain
// in an external (non-netbird) filter table triggers the reconciler.
// Requires CAP_NET_ADMIN; skip otherwise.
func TestExternalChainMonitorRootIntegration(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip("root required")
}
calls := make(chan struct{}, 8)
var count atomic.Int32
rec := &countingReconciler{calls: calls, count: &count}
m := newExternalChainMonitor(rec)
m.start()
t.Cleanup(m.stop)
// Give the netlink subscription a moment to register.
time.Sleep(200 * time.Millisecond)
conn := &nftables.Conn{}
table := conn.AddTable(&nftables.Table{
Name: "nbmon_integration_test",
Family: nftables.TableFamilyINet,
})
t.Cleanup(func() {
cleanup := &nftables.Conn{}
cleanup.DelTable(table)
_ = cleanup.Flush()
})
chain := conn.AddChain(&nftables.Chain{
Name: "filter_INPUT",
Table: table,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
})
_ = chain
require.NoError(t, conn.Flush(), "create external test chain")
select {
case <-calls:
// success
case <-time.After(3 * time.Second):
t.Fatalf("reconcile was not invoked after creating an external chain")
}
require.GreaterOrEqual(t, count.Load(), int32(1))
}
type countingReconciler struct {
calls chan struct{}
count *atomic.Int32
}
func (c *countingReconciler) reconcileExternalChains() error {
c.count.Add(1)
select {
case c.calls <- struct{}{}:
default:
}
return nil
}

View File

@@ -0,0 +1,199 @@
package nftables
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/nftables"
log "github.com/sirupsen/logrus"
)
const (
externalMonitorReconcileDelay = 500 * time.Millisecond
externalMonitorInitInterval = 5 * time.Second
externalMonitorMaxInterval = 5 * time.Minute
externalMonitorRandomization = 0.5
)
// externalChainReconciler re-applies passthrough accept rules to external
// nftables chains. Implementations must be safe to call from the monitor
// goroutine; the Manager locks its mutex internally.
type externalChainReconciler interface {
reconcileExternalChains() error
}
// externalChainMonitor watches nftables netlink events and triggers a
// reconcile when a new table or chain appears (e.g. after
// `firewall-cmd --reload`). Netlink errors trigger exponential-backoff
// reconnect.
type externalChainMonitor struct {
reconciler externalChainReconciler
mu sync.Mutex
cancel context.CancelFunc
done chan struct{}
}
func newExternalChainMonitor(r externalChainReconciler) *externalChainMonitor {
return &externalChainMonitor{reconciler: r}
}
func (m *externalChainMonitor) start() {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel != nil {
return
}
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
m.done = make(chan struct{})
go m.run(ctx)
}
func (m *externalChainMonitor) stop() {
m.mu.Lock()
cancel := m.cancel
done := m.done
m.cancel = nil
m.done = nil
m.mu.Unlock()
if cancel == nil {
return
}
cancel()
<-done
}
func (m *externalChainMonitor) run(ctx context.Context) {
defer close(m.done)
bo := &backoff.ExponentialBackOff{
InitialInterval: externalMonitorInitInterval,
RandomizationFactor: externalMonitorRandomization,
Multiplier: backoff.DefaultMultiplier,
MaxInterval: externalMonitorMaxInterval,
MaxElapsedTime: 0,
Clock: backoff.SystemClock,
}
bo.Reset()
for ctx.Err() == nil {
err := m.watch(ctx)
if ctx.Err() != nil {
return
}
delay := bo.NextBackOff()
log.Warnf("external chain monitor: %v, reconnecting in %s", err, delay)
select {
case <-ctx.Done():
return
case <-time.After(delay):
}
}
}
func (m *externalChainMonitor) watch(ctx context.Context) error {
events, closeMon, err := m.subscribe()
if err != nil {
return err
}
defer closeMon()
debounce := time.NewTimer(time.Hour)
if !debounce.Stop() {
<-debounce.C
}
defer debounce.Stop()
pending := false
for {
select {
case <-ctx.Done():
return nil
case <-debounce.C:
pending = false
m.reconcile()
case ev, ok := <-events:
if !ok {
return errors.New("monitor channel closed")
}
if ev.Error != nil {
return fmt.Errorf("monitor event: %w", ev.Error)
}
if !isRelevantMonitorEvent(ev) {
continue
}
resetDebounce(debounce, pending)
pending = true
}
}
}
func (m *externalChainMonitor) subscribe() (chan *nftables.MonitorEvent, func(), error) {
conn := &nftables.Conn{}
mon := nftables.NewMonitor(
nftables.WithMonitorAction(nftables.MonitorActionNew),
nftables.WithMonitorObject(nftables.MonitorObjectChains|nftables.MonitorObjectTables),
)
events, err := conn.AddMonitor(mon)
if err != nil {
return nil, nil, fmt.Errorf("add netlink monitor: %w", err)
}
return events, func() { _ = mon.Close() }, nil
}
// resetDebounce reschedules a pending debounce timer without leaking a stale
// fire on its channel. pending must reflect whether the timer is armed.
func resetDebounce(t *time.Timer, pending bool) {
if pending && !t.Stop() {
select {
case <-t.C:
default:
}
}
t.Reset(externalMonitorReconcileDelay)
}
func (m *externalChainMonitor) reconcile() {
if err := m.reconciler.reconcileExternalChains(); err != nil {
log.Warnf("reconcile external chain rules: %v", err)
}
}
// isRelevantMonitorEvent returns true for table/chain creation events on
// families we care about. The reconciler filters to actual external filter
// chains.
func isRelevantMonitorEvent(ev *nftables.MonitorEvent) bool {
switch ev.Type {
case nftables.MonitorEventTypeNewChain:
chain, ok := ev.Data.(*nftables.Chain)
if !ok || chain == nil || chain.Table == nil {
return false
}
return isMonitoredFamily(chain.Table.Family)
case nftables.MonitorEventTypeNewTable:
table, ok := ev.Data.(*nftables.Table)
if !ok || table == nil {
return false
}
return isMonitoredFamily(table.Family)
}
return false
}
func isMonitoredFamily(family nftables.TableFamily) bool {
switch family {
case nftables.TableFamilyIPv4, nftables.TableFamilyIPv6, nftables.TableFamilyINet:
return true
}
return false
}

View File

@@ -0,0 +1,137 @@
package nftables
import (
"testing"
"github.com/google/nftables"
"github.com/stretchr/testify/assert"
)
func TestIsMonitoredFamily(t *testing.T) {
tests := []struct {
family nftables.TableFamily
want bool
}{
{nftables.TableFamilyIPv4, true},
{nftables.TableFamilyIPv6, true},
{nftables.TableFamilyINet, true},
{nftables.TableFamilyARP, false},
{nftables.TableFamilyBridge, false},
{nftables.TableFamilyNetdev, false},
{nftables.TableFamilyUnspecified, false},
}
for _, tc := range tests {
assert.Equal(t, tc.want, isMonitoredFamily(tc.family), "family=%d", tc.family)
}
}
func TestIsRelevantMonitorEvent(t *testing.T) {
inetTable := &nftables.Table{Name: "firewalld", Family: nftables.TableFamilyINet}
ipTable := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4}
arpTable := &nftables.Table{Name: "arp", Family: nftables.TableFamilyARP}
tests := []struct {
name string
ev *nftables.MonitorEvent
want bool
}{
{
name: "new chain in inet firewalld",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable},
},
want: true,
},
{
name: "new chain in ip filter",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: &nftables.Chain{Name: "INPUT", Table: ipTable},
},
want: true,
},
{
name: "new chain in unwatched arp family",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: &nftables.Chain{Name: "x", Table: arpTable},
},
want: false,
},
{
name: "new table inet",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewTable,
Data: inetTable,
},
want: true,
},
{
name: "del chain (we only act on new)",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeDelChain,
Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable},
},
want: false,
},
{
name: "chain with nil table",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: &nftables.Chain{Name: "x"},
},
want: false,
},
{
name: "nil data",
ev: &nftables.MonitorEvent{
Type: nftables.MonitorEventTypeNewChain,
Data: (*nftables.Chain)(nil),
},
want: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, isRelevantMonitorEvent(tc.ev))
})
}
}
// fakeReconciler records reconcile invocations for debounce tests.
type fakeReconciler struct {
calls chan struct{}
}
func (f *fakeReconciler) reconcileExternalChains() error {
f.calls <- struct{}{}
return nil
}
func TestExternalChainMonitorStopWithoutStart(t *testing.T) {
m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)})
// Must not panic or block.
m.stop()
}
func TestExternalChainMonitorDoubleStart(t *testing.T) {
// start() twice should be a no-op; stop() cleans up once.
// We avoid exercising the netlink watch loop here because it needs root.
m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)})
// Replace run with a stub that just waits for cancel, so start() stays
// deterministic without opening a netlink socket.
origDone := make(chan struct{})
m.done = origDone
m.cancel = func() { close(origDone) }
// Second start should be a no-op (cancel already set).
m.start()
assert.NotNil(t, m.cancel)
m.stop()
assert.Nil(t, m.cancel)
assert.Nil(t, m.done)
}

View File

@@ -11,9 +11,11 @@ import (
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld" "github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -49,10 +51,17 @@ type Manager struct {
rConn *nftables.Conn rConn *nftables.Conn
wgIface iFaceMapper wgIface iFaceMapper
router *router router *router
aclManager *AclManager aclManager *AclManager
// IPv6 counterparts, nil when no v6 overlay
router6 *router
aclManager6 *AclManager
notrackOutputChain *nftables.Chain notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain notrackPreroutingChain *nftables.Chain
extMonitor *externalChainMonitor
} }
// Create nftables firewall manager // Create nftables firewall manager
@@ -62,7 +71,8 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
wgIface: wgIface, wgIface: wgIface,
} }
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} tableName := getTableName()
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
var err error var err error
m.router, err = newRouter(workTable, wgIface, mtu) m.router, err = newRouter(workTable, wgIface, mtu)
@@ -75,35 +85,137 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
return nil, fmt.Errorf("create acl manager: %w", err) return nil, fmt.Errorf("create acl manager: %w", err)
} }
if wgIface.Address().HasIPv6() {
if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil {
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
}
}
m.extMonitor = newExternalChainMonitor(m)
return m, nil return m, nil
} }
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error {
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
var err error
m.router6, err = newRouter(workTable6, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
// hasIPv6 reports whether the manager has IPv6 components initialized.
func (m *Manager) hasIPv6() bool {
return m.router6 != nil
}
func (m *Manager) initIPv6() error {
workTable6, err := m.createWorkTableFamily(nftables.TableFamilyIPv6)
if err != nil {
return fmt.Errorf("create v6 work table: %w", err)
}
if err := m.router6.init(workTable6); err != nil {
return fmt.Errorf("v6 router init: %w", err)
}
if err := m.aclManager6.init(workTable6); err != nil {
return fmt.Errorf("v6 acl manager init: %w", err)
}
return nil
}
// Init nftables firewall manager // Init nftables firewall manager
func (m *Manager) Init(stateManager *statemanager.Manager) error { func (m *Manager) Init(stateManager *statemanager.Manager) error {
if err := m.initFirewall(); err != nil {
return err
}
m.persistState(stateManager)
// Start after initFirewall has installed the baseline external-chain
// accept rules. start() is idempotent across Init/Close/Init cycles.
m.extMonitor.start()
return nil
}
// reconcileExternalChains re-applies passthrough accept rules to external
// filter chains for both IPv4 and IPv6 routers. Called by the monitor when
// tables or chains appear (e.g. after firewalld reloads).
func (m *Manager) reconcileExternalChains() error {
m.mutex.Lock()
defer m.mutex.Unlock()
var merr *multierror.Error
if m.router != nil {
if err := m.router.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
}
}
if m.hasIPv6() {
if err := m.router6.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (m *Manager) initFirewall() (err error) {
workTable, err := m.createWorkTable() workTable, err := m.createWorkTable()
if err != nil { if err != nil {
return fmt.Errorf("create work table: %w", err) return fmt.Errorf("create work table: %w", err)
} }
defer func() {
if err != nil {
m.rollbackInit()
}
}()
if err := m.router.init(workTable); err != nil { if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err) return fmt.Errorf("router init: %w", err)
} }
if err := m.aclManager.init(workTable); err != nil { if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err) return fmt.Errorf("acl manager init: %w", err)
} }
if m.hasIPv6() {
if err := m.initIPv6(); err != nil {
// Peer has a v6 address: v6 firewall MUST work or we risk fail-open.
return fmt.Errorf("init IPv6 firewall (required because peer has IPv6 address): %w", err)
}
}
if err := m.initNoTrackChains(workTable); err != nil { if err := m.initNoTrackChains(workTable); err != nil {
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err) log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
} }
return nil
}
// persistState saves the current interface state for potential recreation on restart.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules.
func (m *Manager) persistState(stateManager *statemanager.Manager) {
stateManager.RegisterState(&ShutdownState{}) stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{ if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{ InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(), NameStr: m.wgIface.Name(),
@@ -114,14 +226,29 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
} }
// persist early
go func() { go func() {
if err := stateManager.PersistState(context.Background()); err != nil { if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
}() }()
}
return nil // rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
func (m *Manager) rollbackInit() {
if err := m.router.Reset(); err != nil {
log.Warnf("rollback router: %v", err)
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
log.Warnf("rollback v6 router: %v", err)
}
}
if err := m.cleanupNetbirdTables(); err != nil {
log.Warnf("cleanup tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
log.Warnf("flush: %v", err)
}
} }
// AddPeerFiltering rule to the firewall // AddPeerFiltering rule to the firewall
@@ -140,12 +267,14 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
rawIP := ip.To4() if ip.To4() != nil {
if rawIP == nil { return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
@@ -159,8 +288,11 @@ func (m *Manager) AddRouteFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() { if isIPv6RouteRule(sources, destination) {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String()) if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -171,15 +303,66 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6Rule(rule) {
return m.aclManager6.DeletePeerRule(rule)
}
return m.aclManager.DeletePeerRule(rule) return m.aclManager.DeletePeerRule(rule)
} }
// DeleteRouteRule deletes a routing rule func isIPv6Rule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
}
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
// For static routes, the destination prefix determines the family. For dynamic
// routes (DomainSet), the sources determine the family since management
// duplicates dynamic rules per family.
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
// router; the cached maps are normally authoritative, so the kernel is only
// consulted when neither map knows about the rule.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule) id := rule.ID()
r, err := m.routerForRuleID(id, (*router).hasRule)
if err != nil {
return err
}
return r.DeleteRouteRule(rule)
}
// routerForRuleID picks the router holding the rule with the given id, using
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
// from the kernel once and re-checks before falling back to the v4 router.
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
if has(m.router, id) {
return m.router, nil
}
if m.hasIPv6() && has(m.router6, id) {
return m.router6, nil
}
if !m.hasIPv6() {
return m.router, nil
}
if err := m.router.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v4 rules: %w", err)
}
if err := m.router6.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v6 rules: %w", err)
}
if has(m.router6, id) && !has(m.router, id) {
return m.router6, nil
}
return m.router, nil
} }
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
@@ -194,19 +377,70 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddNatRule(pair) if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
return err
}
// Dynamic routes need NAT in both tables since resolved IPs can be
// either v4 or v6. This covers both DomainSet (modern) and the legacy
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
// On v6 failure we keep the v4 NAT rule rather than rolling back: half
// connectivity is better than none, and RemoveNatRule is content-keyed
// so the eventual cleanup still works.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
return nil
} }
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair) if pair.Destination.IsPrefix() && pair.Destination.Prefix.Addr().Is6() {
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.router.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
// AllowNetbird allows netbird interface traffic. // AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept // This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter. // rules so that packet filtering is handled in userspace instead of by netfilter.
//
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
// which doesn't override DROP rules in external tables (e.g. firewalld).
// Should add passthrough rules to external chains (like the native mode router's
// addExternalChainsRules does) for both the netbird table family and inet tables.
// The netbird table itself is fine (routing chains already exist there), but
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
func (m *Manager) AllowNetbird() error { func (m *Manager) AllowNetbird() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -214,6 +448,11 @@ func (m *Manager) AllowNetbird() error {
if err := m.aclManager.createDefaultAllowRules(); err != nil { if err := m.aclManager.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err) return fmt.Errorf("create default allow rules: %w", err)
} }
if m.hasIPv6() {
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create v6 default allow rules: %w", err)
}
}
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush allow input netbird rules: %w", err) return fmt.Errorf("flush allow input netbird rules: %w", err)
} }
@@ -227,31 +466,47 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management // SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error { func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy) if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
} }
// Close closes the firewall manager // Close closes the firewall manager
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.extMonitor.stop()
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var merr *multierror.Error
if err := m.router.Reset(); err != nil { if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err) merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
}
} }
if err := m.cleanupNetbirdTables(); err != nil { if err := m.cleanupNetbirdTables(); err != nil {
return fmt.Errorf("cleanup netbird tables: %v", err) merr = multierror.Append(merr, fmt.Errorf("cleanup netbird tables: %v", err))
} }
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err) merr = multierror.Append(merr, fmt.Errorf(flushError, err))
} }
if err := stateManager.DeleteState(&ShutdownState{}); err != nil { if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
return fmt.Errorf("delete state: %v", err) merr = multierror.Append(merr, fmt.Errorf("delete state: %v", err))
} }
return nil return nberrors.FormatErrorOrNil(merr)
} }
func (m *Manager) cleanupNetbirdTables() error { func (m *Manager) cleanupNetbirdTables() error {
@@ -300,6 +555,12 @@ func (m *Manager) Flush() error {
return err return err
} }
if m.hasIPv6() {
if err := m.aclManager6.Flush(); err != nil {
return fmt.Errorf("flush v6 acl: %w", err)
}
}
if err := m.refreshNoTrackChains(); err != nil { if err := m.refreshNoTrackChains(); err != nil {
log.Errorf("failed to refresh notrack chains: %v", err) log.Errorf("failed to refresh notrack chains: %v", err)
} }
@@ -312,6 +573,12 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if rule.TranslatedAddress.Is6() {
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule) return m.router.AddDNATRule(rule)
} }
@@ -320,7 +587,11 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule) r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
if err != nil {
return err
}
return r.DeleteDNATRule(rule)
} }
// UpdateSet updates the set with the given prefixes // UpdateSet updates the set with the given prefixes
@@ -328,39 +599,82 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes) var v4Prefixes, v6Prefixes []netip.Prefix
for _, p := range prefixes {
if p.Addr().Is6() {
v6Prefixes = append(v6Prefixes, p)
} else {
v4Prefixes = append(v4Prefixes, p)
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
return nil
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
} }
// RemoveInboundDNAT removes an inbound DNAT rule. // RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
} }
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
} }
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. // RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) if localAddr.Is6() {
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
} }
const ( const (
@@ -534,7 +848,11 @@ func (m *Manager) refreshNoTrackChains() error {
} }
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) return m.createWorkTableFamily(nftables.TableFamilyIPv4)
}
func (m *Manager) createWorkTableFamily(family nftables.TableFamily) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of tables: %w", err) return nil, fmt.Errorf("list of tables: %w", err)
} }
@@ -546,7 +864,7 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
} }
} }
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: family})
err = m.rConn.Flush() err = m.rConn.Flush()
return table, err return table, err
} }

View File

@@ -383,10 +383,138 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
err = manager.AddNatRule(pair) err = manager.AddNatRule(pair)
require.NoError(t, err, "failed to add NAT rule") require.NoError(t, err, "failed to add NAT rule")
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{8080}},
TranslatedAddress: netip.MustParseAddr("100.96.0.2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
})
require.NoError(t, err, "failed to add DNAT rule")
t.Cleanup(func() {
require.NoError(t, manager.DeleteDNATRule(dnatRule), "failed to delete DNAT rule")
})
stdout, stderr = runIptablesSave(t) stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
} }
func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
for _, bin := range []string{"ip6tables", "ip6tables-save", "iptables-save"} {
if _, err := exec.LookPath(bin); err != nil {
t.Skipf("%s not available on this system: %v", bin, err)
}
}
// Seed ip6 tables in the nft backend. Docker may not create them.
seedIp6tables(t)
ifaceMockV6 := &iFaceMock{
NameFunc: func() string { return "wt-test" },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
manager, err := Create(ifaceMockV6, iface.DefaultMTU)
require.NoError(t, err, "create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil), "close manager")
stdout, stderr := runIp6tablesSave(t)
verifyIp6tablesOutput(t, stdout, stderr)
})
ip := netip.MustParseAddr("fd00::2")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "add v6 peer filtering rule")
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "add v6 route filtering rule")
err = manager.AddNatRule(fw.RouterPair{
Source: fw.Network{Prefix: netip.MustParsePrefix("fd00::/64")},
Destination: fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
Masquerade: true,
})
require.NoError(t, err, "add v6 NAT rule")
dnatRule, err := manager.AddDNATRule(fw.ForwardRule{
Protocol: fw.ProtocolTCP,
DestinationPort: fw.Port{Values: []uint16{8080}},
TranslatedAddress: netip.MustParseAddr("fd00::2"),
TranslatedPort: fw.Port{Values: []uint16{80}},
})
require.NoError(t, err, "add v6 DNAT rule")
t.Cleanup(func() {
require.NoError(t, manager.DeleteDNATRule(dnatRule), "delete v6 DNAT rule")
})
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
stdout, stderr = runIp6tablesSave(t)
verifyIp6tablesOutput(t, stdout, stderr)
}
func seedIp6tables(t *testing.T) {
t.Helper()
for _, tc := range []struct{ table, chain string }{
{"filter", "FORWARD"},
{"nat", "POSTROUTING"},
{"mangle", "FORWARD"},
} {
add := exec.Command("ip6tables", "-t", tc.table, "-A", tc.chain, "-j", "ACCEPT")
require.NoError(t, add.Run(), "seed ip6tables -t %s", tc.table)
del := exec.Command("ip6tables", "-t", tc.table, "-D", tc.chain, "-j", "ACCEPT")
require.NoError(t, del.Run(), "unseed ip6tables -t %s", tc.table)
}
}
func runIp6tablesSave(t *testing.T) (string, string) {
t.Helper()
var stdout, stderr bytes.Buffer
cmd := exec.Command("ip6tables-save")
cmd.Stdout = &stdout
cmd.Stderr = &stderr
require.NoError(t, cmd.Run(), "ip6tables-save failed")
return stdout.String(), stderr.String()
}
func verifyIp6tablesOutput(t *testing.T, stdout, stderr string) {
t.Helper()
for _, msg := range []string{
"Table `nat' is incompatible",
"Table `mangle' is incompatible",
"Table `filter' is incompatible",
} {
require.NotContains(t, stdout, msg,
"ip6tables-save stdout reports incompatibility: %s", stdout)
require.NotContains(t, stderr, msg,
"ip6tables-save stderr reports incompatibility: %s", stderr)
}
}
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) { func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
if check() != NFTABLES { if check() != NFTABLES {
t.Skip("nftables not supported on this system") t.Skip("nftables not supported on this system")

View File

@@ -50,8 +50,10 @@ const (
dnatSuffix = "_dnat" dnatSuffix = "_dnat"
snatSuffix = "_snat" snatSuffix = "_snat"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation // ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipTCPHeaderMinSize = 40 ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
// maxPrefixesSet 1638 prefixes start to fail, taking some margin // maxPrefixesSet 1638 prefixes start to fail, taking some margin
maxPrefixesSet = 1500 maxPrefixesSet = 1500
@@ -76,6 +78,7 @@ type router struct {
rules map[string]*nftables.Rule rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set] ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
af addrFamily
wgIface iFaceMapper wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool legacyManagement bool
@@ -88,6 +91,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
workTable: workTable, workTable: workTable,
chains: make(map[string]*nftables.Chain), chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule), rules: make(map[string]*nftables.Rule),
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(), ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu, mtu: mtu,
@@ -150,7 +154,7 @@ func (r *router) Reset() error {
func (r *router) removeNatPreroutingRules() error { func (r *router) removeNatPreroutingRules() error {
table := &nftables.Table{ table := &nftables.Table{
Name: tableNat, Name: tableNat,
Family: nftables.TableFamilyIPv4, Family: r.af.tableFamily,
} }
chain := &nftables.Chain{ chain := &nftables.Chain{
Name: chainNameNatPrerouting, Name: chainNameNatPrerouting,
@@ -183,7 +187,7 @@ func (r *router) removeNatPreroutingRules() error {
} }
func (r *router) loadFilterTable() (*nftables.Table, error) { func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
if err != nil { if err != nil {
return nil, fmt.Errorf("list tables: %w", err) return nil, fmt.Errorf("list tables: %w", err)
} }
@@ -419,7 +423,7 @@ func (r *router) AddRouteFiltering(
// Handle protocol // Handle protocol
if proto != firewall.ProtocolALL { if proto != firewall.ProtocolALL {
protoNum, err := protoToInt(proto) protoNum, err := r.af.protoNum(proto)
if err != nil { if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err) return nil, fmt.Errorf("convert protocol to number: %w", err)
} }
@@ -479,7 +483,24 @@ func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bo
return nil, fmt.Errorf("create or get ipset: %w", err) return nil, fmt.Errorf("create or get ipset: %w", err)
} }
return getIpSetExprs(ref, isSource) return r.getIpSetExprs(ref, isSource)
}
func (r *router) iptablesProto() iptables.Protocol {
if r.af.tableFamily == nftables.TableFamilyIPv6 {
return iptables.ProtocolIPv6
}
return iptables.ProtocolIPv4
}
func (r *router) hasRule(id string) bool {
_, ok := r.rules[id]
return ok
}
func (r *router) hasDNATRule(id string) bool {
_, ok := r.rules[id+dnatSuffix]
return ok
} }
func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) DeleteRouteRule(rule firewall.Rule) error {
@@ -528,10 +549,10 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
Table: r.workTable, Table: r.workTable,
// required for prefixes // required for prefixes
Interval: true, Interval: true,
KeyType: nftables.TypeIPAddr, KeyType: r.af.setKeyType,
} }
elements := convertPrefixesToSet(prefixes) elements := r.convertPrefixesToSet(prefixes)
nElements := len(elements) nElements := len(elements)
maxElements := maxPrefixesSet * 2 maxElements := maxPrefixesSet * 2
@@ -564,23 +585,17 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
return nfset, nil return nfset, nil
} }
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement { func (r *router) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement var elements []nftables.SetElement
for _, prefix := range prefixes { for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
// nftables needs half-open intervals [firstIP, lastIP) for prefixes // nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc // e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr() firstIP := prefix.Addr()
lastIP := calculateLastIP(prefix).Next() lastIP := calculateLastIP(prefix).Next()
elements = append(elements, elements = append(elements,
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247 // the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true}, // nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
nftables.SetElement{Key: firstIP.AsSlice()}, nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
) )
@@ -590,10 +605,20 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
// calculateLastIP determines the last IP in a given prefix. // calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr { func calculateLastIP(prefix netip.Prefix) netip.Addr {
hostMask := ^uint32(0) >> prefix.Masked().Bits() masked := prefix.Masked()
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask if masked.Addr().Is4() {
hostMask := ^uint32(0) >> masked.Bits()
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
return netip.AddrFrom4(uint32ToBytes(lastIP)) // IPv6: set host bits to all 1s
b := masked.Addr().As16()
bits := masked.Bits()
for i := bits; i < 128; i++ {
b[i/8] |= 1 << (7 - i%8)
}
return netip.AddrFrom16(b)
} }
// Utility function to convert netip.Addr to uint32. // Utility function to convert netip.Addr to uint32.
@@ -845,9 +870,16 @@ func (r *router) addPostroutingRules() {
} }
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error { func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize overhead := uint16(ipv4TCPHeaderSize)
if r.af.tableFamily == nftables.TableFamilyIPv6 {
overhead = ipv6TCPHeaderSize
}
if r.mtu <= overhead {
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
return nil
}
mss := r.mtu - overhead
exprsOut := []expr.Any{ exprsOut := []expr.Any{
&expr.Meta{ &expr.Meta{
@@ -1054,17 +1086,22 @@ func (r *router) acceptFilterTableRules() error {
log.Debugf("Used %s to add accept forward and input rules", fw) log.Debugf("Used %s to add accept forward and input rules", fw)
}() }()
// Try iptables first and fallback to nftables if iptables is not available // Try iptables first and fallback to nftables if iptables is not available.
ipt, err := iptables.New() // Use the correct protocol (iptables vs ip6tables) for the address family.
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil { if err != nil {
// iptables is not available but the filter table exists
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables" fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable) return r.acceptFilterRulesNftables(r.filterTable)
} }
return r.acceptFilterRulesIptables(ipt) if err := r.acceptFilterRulesIptables(ipt); err != nil {
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable)
}
return nil
} }
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
@@ -1135,83 +1172,122 @@ func (r *router) acceptExternalChainsRules() error {
} }
intf := ifname(r.wgIface.Name()) intf := ifname(r.wgIface.Name())
for _, chain := range chains { for _, chain := range chains {
if chain.Hooknum == nil { r.applyExternalChainAccept(chain, intf)
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
continue
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush external chain rules: %w", err) return fmt.Errorf("flush external chain rules: %w", err)
} }
return nil return nil
} }
func (r *router) applyExternalChainAccept(chain *nftables.Chain, intf []byte) {
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
return
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) { func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
iifRule := &nftables.Rule{ existing, err := r.existingNetbirdRulesInChain(chain)
if err != nil {
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
return
}
r.insertForwardIifRule(chain, intf, existing)
r.insertForwardOifEstablishedRule(chain, intf, existing)
}
func (r *router) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
if existing[userDataAcceptForwardRuleIif] {
return
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table, Table: chain.Table,
Chain: chain, Chain: chain,
Exprs: []expr.Any{ Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{}, &expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept}, &expr.Verdict{Kind: expr.VerdictAccept},
}, },
UserData: []byte(userDataAcceptForwardRuleIif), UserData: []byte(userDataAcceptForwardRuleIif),
} })
r.conn.InsertRule(iifRule) }
oifExprs := []expr.Any{ func (r *router) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, if existing[userDataAcceptForwardRuleOif] {
&expr.Cmp{ return
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
} }
oifRule := &nftables.Rule{ exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table, Table: chain.Table,
Chain: chain, Chain: chain,
Exprs: append(oifExprs, getEstablishedExprs(2)...), Exprs: append(exprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif), UserData: []byte(userDataAcceptForwardRuleOif),
} })
r.conn.InsertRule(oifRule)
} }
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) { func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
inputRule := &nftables.Rule{ existing, err := r.existingNetbirdRulesInChain(chain)
if err != nil {
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
return
}
if existing[userDataAcceptInputRule] {
return
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table, Table: chain.Table,
Chain: chain, Chain: chain,
Exprs: []expr.Any{ Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{}, &expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept}, &expr.Verdict{Kind: expr.VerdictAccept},
}, },
UserData: []byte(userDataAcceptInputRule), UserData: []byte(userDataAcceptInputRule),
})
}
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
func (r *router) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return nil, fmt.Errorf("list rules: %w", err)
} }
r.conn.InsertRule(inputRule) present := map[string]bool{}
for _, rule := range rules {
if !isNetbirdAcceptRuleTag(rule.UserData) {
continue
}
present[string(rule.UserData)] = true
}
return present, nil
}
func isNetbirdAcceptRuleTag(userData []byte) bool {
switch string(userData) {
case userDataAcceptForwardRuleIif,
userDataAcceptForwardRuleOif,
userDataAcceptInputRule:
return true
}
return false
} }
func (r *router) removeAcceptFilterRules() error { func (r *router) removeAcceptFilterRules() error {
@@ -1233,13 +1309,17 @@ func (r *router) removeFilterTableRules() error {
return nil return nil
} }
ipt, err := iptables.New() ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil { if err != nil {
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err) log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable) return r.removeAcceptRulesFromTable(r.filterTable)
} }
return r.removeAcceptFilterRulesIptables(ipt) if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return nil
} }
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error { func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
@@ -1306,7 +1386,7 @@ func (r *router) removeExternalChainsRules() error {
func (r *router) findExternalChains() []*nftables.Chain { func (r *router) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain var chains []*nftables.Chain
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet} families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
for _, family := range families { for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family) allChains, err := r.conn.ListChainsOfTableFamily(family)
@@ -1337,8 +1417,8 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
return false return false
} }
// Skip all iptables-managed tables in the ip family // Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
return false return false
} }
@@ -1479,7 +1559,7 @@ func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
return rule, nil return rule, nil
} }
protoNum, err := protoToInt(rule.Protocol) protoNum, err := r.af.protoNum(rule.Protocol)
if err != nil { if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err) return nil, fmt.Errorf("convert protocol to number: %w", err)
} }
@@ -1542,7 +1622,7 @@ func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, rule
dnatExprs = append(dnatExprs, dnatExprs = append(dnatExprs,
&expr.NAT{ &expr.NAT{
Type: expr.NATTypeDestNAT, Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4), Family: uint32(r.af.tableFamily),
RegAddrMin: 1, RegAddrMin: 1,
RegProtoMin: regProtoMin, RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax, RegProtoMax: regProtoMax,
@@ -1635,14 +1715,15 @@ func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule f
}, },
) )
natTable := &nftables.Table{
Name: tableNat,
Family: r.af.tableFamily,
}
dnatRule := &nftables.Rule{ dnatRule := &nftables.Rule{
Table: &nftables.Table{ Table: natTable,
Name: tableNat,
Family: nftables.TableFamilyIPv4,
},
Chain: &nftables.Chain{ Chain: &nftables.Chain{
Name: chainNameNatPrerouting, Name: chainNameNatPrerouting,
Table: r.filterTable, Table: natTable,
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
Hooknum: nftables.ChainHookPrerouting, Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest, Priority: nftables.ChainPriorityNATDest,
@@ -1673,8 +1754,8 @@ func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: 16, Offset: r.af.dstAddrOffset,
Len: 4, Len: r.af.addrLen,
}, },
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
@@ -1752,7 +1833,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return fmt.Errorf("get set %s: %w", set.HashedName(), err) return fmt.Errorf("get set %s: %w", set.HashedName(), err)
} }
elements := convertPrefixesToSet(prefixes) elements := r.convertPrefixesToSet(prefixes)
if err := r.conn.SetAddElements(nfset, elements); err != nil { if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err) return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
} }
@@ -1767,14 +1848,14 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
if _, exists := r.rules[ruleID]; exists { if _, exists := r.rules[ruleID]; exists {
return nil return nil
} }
protoNum, err := protoToInt(protocol) protoNum, err := r.af.protoNum(protocol)
if err != nil { if err != nil {
return fmt.Errorf("convert protocol to number: %w", err) return fmt.Errorf("convert protocol to number: %w", err)
} }
@@ -1801,11 +1882,15 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 3, Register: 3,
Data: binaryutil.BigEndian.PutUint16(sourcePort), Data: binaryutil.BigEndian.PutUint16(originalPort),
}, },
} }
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs, exprs = append(exprs,
&expr.Immediate{ &expr.Immediate{
@@ -1814,11 +1899,11 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
}, },
&expr.Immediate{ &expr.Immediate{
Register: 2, Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort), Data: binaryutil.BigEndian.PutUint16(translatedPort),
}, },
&expr.NAT{ &expr.NAT{
Type: expr.NATTypeDestNAT, Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4), Family: uint32(r.af.tableFamily),
RegAddrMin: 1, RegAddrMin: 1,
RegProtoMin: 2, RegProtoMin: 2,
RegProtoMax: 0, RegProtoMax: 0,
@@ -1843,12 +1928,12 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
} }
// RemoveInboundDNAT removes an inbound DNAT rule. // RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
rule, exists := r.rules[ruleID] rule, exists := r.rules[ruleID]
if !exists { if !exists {
@@ -1894,8 +1979,8 @@ func (r *router) ensureNATOutputChain() error {
} }
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic. // AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
if _, exists := r.rules[ruleID]; exists { if _, exists := r.rules[ruleID]; exists {
return nil return nil
@@ -1905,7 +1990,7 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
return err return err
} }
protoNum, err := protoToInt(protocol) protoNum, err := r.af.protoNum(protocol)
if err != nil { if err != nil {
return fmt.Errorf("convert protocol to number: %w", err) return fmt.Errorf("convert protocol to number: %w", err)
} }
@@ -1926,11 +2011,15 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 2, Register: 2,
Data: binaryutil.BigEndian.PutUint16(sourcePort), Data: binaryutil.BigEndian.PutUint16(originalPort),
}, },
} }
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs, exprs = append(exprs,
&expr.Immediate{ &expr.Immediate{
@@ -1939,11 +2028,11 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
}, },
&expr.Immediate{ &expr.Immediate{
Register: 2, Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort), Data: binaryutil.BigEndian.PutUint16(translatedPort),
}, },
&expr.NAT{ &expr.NAT{
Type: expr.NATTypeDestNAT, Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4), Family: uint32(r.af.tableFamily),
RegAddrMin: 1, RegAddrMin: 1,
RegProtoMin: 2, RegProtoMin: 2,
}, },
@@ -1967,12 +2056,12 @@ func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol,
} }
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule. // RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort)
rule, exists := r.rules[ruleID] rule, exists := r.rules[ruleID]
if !exists { if !exists {
@@ -2011,45 +2100,44 @@ func (r *router) applyNetwork(
} }
if network.IsPrefix() { if network.IsPrefix() {
return applyPrefix(network.Prefix, isSource), nil return r.applyPrefix(network.Prefix, isSource), nil
} }
return nil, nil return nil, nil
} }
// applyPrefix generates nftables expressions for a CIDR prefix // applyPrefix generates nftables expressions for a CIDR prefix
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any { func (r *router) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset // dst offset by default
offset := uint32(16) offset := r.af.dstAddrOffset
if isSource { if isSource {
// src offset // src offset
offset = 12 offset = r.af.srcAddrOffset
} }
ones := prefix.Bits() ones := prefix.Bits()
// 0.0.0.0/0 doesn't need extra expressions // unspecified address (/0) doesn't need extra expressions
if ones == 0 { if ones == 0 {
return nil return nil
} }
mask := net.CIDRMask(ones, 32) mask := net.CIDRMask(ones, r.af.totalBits)
xor := make([]byte, r.af.addrLen)
return []expr.Any{ return []expr.Any{
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: offset, Offset: offset,
Len: 4, Len: r.af.addrLen,
}, },
// netmask
&expr.Bitwise{ &expr.Bitwise{
DestRegister: 1, DestRegister: 1,
SourceRegister: 1, SourceRegister: 1,
Len: 4, Len: r.af.addrLen,
Mask: mask, Mask: mask,
Xor: []byte{0, 0, 0, 0}, Xor: xor,
}, },
// net address
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
@@ -2132,13 +2220,12 @@ func getCtNewExprs() []expr.Any {
} }
} }
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) { func (r *router) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset by default
// dst offset offset := r.af.dstAddrOffset
offset := uint32(16)
if isSource { if isSource {
// src offset // src offset
offset = 12 offset = r.af.srcAddrOffset
} }
return []expr.Any{ return []expr.Any{
@@ -2146,7 +2233,7 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
Offset: offset, Offset: offset,
Len: 4, Len: r.af.addrLen,
}, },
&expr.Lookup{ &expr.Lookup{
SourceRegister: 1, SourceRegister: 1,

View File

@@ -90,8 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
} }
// Build CIDR matching expressions // Build CIDR matching expressions
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true) testRouter := &router{af: afIPv4}
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false) sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
// Combine all expressions in the correct order // Combine all expressions in the correct order
// nolint:gocritic // nolint:gocritic
@@ -508,6 +509,136 @@ func TestNftablesCreateIpSet(t *testing.T) {
} }
} }
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTableIPv6()
require.NoError(t, err, "Failed to create v6 work table")
defer deleteWorkTableIPv6()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
expected []netip.Prefix
}{
{
name: "Single IPv6",
sources: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")},
},
{
name: "Multiple IPv6 Subnets",
sources: []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::/48"),
netip.MustParsePrefix("fe80::/10"),
},
},
{
name: "Overlapping IPv6",
sources: []netip.Prefix{
netip.MustParsePrefix("fd00::/48"),
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("fd00::1/128"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("fd00::/48"),
},
},
{
name: "Mixed prefix lengths",
sources: []netip.Prefix{
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::1/128"),
netip.MustParsePrefix("fd00:abcd::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
require.NoError(t, err, "Failed to create IPv6 set")
require.NotNil(t, set)
assert.Equal(t, setName, set.Name)
assert.True(t, set.Interval)
assert.Equal(t, nftables.TypeIP6Addr, set.KeyType)
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
require.NoError(t, err, "Failed to fetch created set")
elements, err := r.conn.GetSetElements(fetchedSet)
require.NoError(t, err, "Failed to get set elements")
uniquePrefixes := make(map[string]bool)
for _, elem := range elements {
if !elem.IntervalEnd && len(elem.Key) == 16 {
ip := netip.AddrFrom16([16]byte(elem.Key))
uniquePrefixes[ip.String()] = true
}
}
expectedCount := len(tt.expected)
if expectedCount == 0 {
expectedCount = len(tt.sources)
}
assert.Equal(t, expectedCount, len(uniquePrefixes), "unique prefix count mismatch")
r.conn.DelSet(set)
require.NoError(t, r.conn.Flush())
})
}
}
func createWorkTableIPv6() (*nftables.Table, error) {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return nil, err
}
for _, t := range tables {
if t.Name == tableNameNetbird {
sConn.DelTable(t)
}
}
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv6})
err = sConn.Flush()
return table, err
}
func deleteWorkTableIPv6() {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv6)
if err != nil {
return
}
for _, t := range tables {
if t.Name == tableNameNetbird {
sConn.DelTable(t)
_ = sConn.Flush()
}
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) { func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
t.Helper() t.Helper()
@@ -627,7 +758,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
var metaFound, cmpFound bool var metaFound, cmpFound bool
expectedProto, _ := protoToInt(proto) expectedProto, _ := afIPv4.protoNum(proto)
for _, e := range exprs { for _, e := range exprs {
switch ex := e.(type) { switch ex := e.(type) {
case *expr.Meta: case *expr.Meta:
@@ -854,3 +985,55 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
} }
assert.Equal(t, 1, found, "NAT rule should exist in kernel") assert.Equal(t, 1, found, "NAT rule should exist in kernel")
} }
func TestCalculateLastIP(t *testing.T) {
tests := []struct {
prefix string
want string
}{
{"10.0.0.0/24", "10.0.0.255"},
{"10.0.0.0/32", "10.0.0.0"},
{"0.0.0.0/0", "255.255.255.255"},
{"192.168.1.0/28", "192.168.1.15"},
{"fd00::/64", "fd00::ffff:ffff:ffff:ffff"},
{"fd00::/128", "fd00::"},
{"2001:db8::/48", "2001:db8:0:ffff:ffff:ffff:ffff:ffff"},
{"::/0", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"},
}
for _, tt := range tests {
t.Run(tt.prefix, func(t *testing.T) {
prefix := netip.MustParsePrefix(tt.prefix)
got := calculateLastIP(prefix)
assert.Equal(t, tt.want, got.String())
})
}
}
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
r := &router{af: afIPv6}
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::1/128"),
}
elements := r.convertPrefixesToSet(prefixes)
// Each prefix produces 2 elements (start + end)
require.Len(t, elements, 4)
// fd00::/64 start
assert.Equal(t, netip.MustParseAddr("fd00::").As16(), [16]byte(elements[0].Key))
assert.False(t, elements[0].IntervalEnd)
// fd00::/64 end (fd00:0:0:1::, one past the last)
assert.Equal(t, netip.MustParseAddr("fd00:0:0:1::").As16(), [16]byte(elements[1].Key))
assert.True(t, elements[1].IntervalEnd)
// 2001:db8::1/128 start
assert.Equal(t, netip.MustParseAddr("2001:db8::1").As16(), [16]byte(elements[2].Key))
assert.False(t, elements[2].IntervalEnd)
// 2001:db8::1/128 end (2001:db8::2)
assert.Equal(t, netip.MustParseAddr("2001:db8::2").As16(), [16]byte(elements[3].Key))
assert.True(t, elements[3].IntervalEnd)
}

View File

@@ -5,8 +5,10 @@ import (
"os/exec" "os/exec"
"syscall" "syscall"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -29,15 +31,20 @@ func (m *Manager) Close(*statemanager.Manager) error {
return nil return nil
} }
if !isFirewallRuleActive(firewallRuleName) { var merr *multierror.Error
return nil if isFirewallRuleActive(firewallRuleName) {
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
}
} }
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil { if isFirewallRuleActive(firewallRuleName + "-v6") {
return fmt.Errorf("couldn't remove windows firewall: %w", err) if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
}
} }
return nil return nberrors.FormatErrorOrNil(merr)
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
@@ -46,17 +53,33 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
if isFirewallRuleActive(firewallRuleName) { if !isFirewallRuleActive(firewallRuleName) {
return nil if err := manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
); err != nil {
return err
}
} }
return manageFirewallRule(firewallRuleName,
addRule, if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
"dir=in", if err := manageFirewallRule(firewallRuleName+"-v6",
"enable=yes", addRule,
"action=allow", "dir=in",
"profile=any", "enable=yes",
"localip="+m.wgIface.Address().IP.String(), "action=allow",
) "profile=any",
"localip="+v6.String(),
); err != nil {
return err
}
}
return nil
} }
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error { func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {

View File

@@ -0,0 +1,125 @@
package conntrack
import (
"net/netip"
"testing"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
)
func TestTCPCapEvicts(t *testing.T) {
t.Setenv(EnvTCPMaxEntries, "4")
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 4, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
for i := 0; i < 10; i++ {
tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0)
}
require.LessOrEqual(t, len(tracker.connections), 4,
"TCP table must not exceed the configured cap")
require.Greater(t, len(tracker.connections), 0,
"some entries must remain after eviction")
// The most recently admitted flow must be present: eviction must make
// room for new entries, not silently drop them.
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80},
"newest TCP flow must be admitted after eviction")
// A pre-cap flow must have been evicted to fit the last one.
require.NotContains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80},
"oldest TCP flow should have been evicted")
}
func TestTCPCapPrefersTombstonedForEviction(t *testing.T) {
t.Setenv(EnvTCPMaxEntries, "3")
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
// Fill to cap with 3 live connections.
for i := 0; i < 3; i++ {
tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0)
}
require.Len(t, tracker.connections, 3)
// Tombstone one by sending RST through IsValidInbound.
tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80}
require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0))
require.True(t, tracker.connections[tombstonedKey].IsTombstone())
// Another live connection forces eviction. The tombstone must go first.
tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0)
_, tombstonedStillPresent := tracker.connections[tombstonedKey]
require.False(t, tombstonedStillPresent,
"tombstoned entry should be evicted before live entries")
require.LessOrEqual(t, len(tracker.connections), 3)
// Both live pre-cap entries must survive: eviction must prefer the
// tombstone, not just satisfy the size bound by dropping any entry.
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20000), DstPort: 80},
"live entries must not be evicted while a tombstone exists")
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20002), DstPort: 80},
"live entries must not be evicted while a tombstone exists")
}
func TestUDPCapEvicts(t *testing.T) {
t.Setenv(EnvUDPMaxEntries, "5")
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 5, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
for i := 0; i < 12; i++ {
tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0)
}
require.LessOrEqual(t, len(tracker.connections), 5)
require.Greater(t, len(tracker.connections), 0)
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53},
"newest UDP flow must be admitted after eviction")
require.NotContains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53},
"oldest UDP flow should have been evicted")
}
func TestICMPCapEvicts(t *testing.T) {
t.Setenv(EnvICMPMaxEntries, "3")
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 3, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
echoReq := layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0)
for i := 0; i < 8; i++ {
tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64)
}
require.LessOrEqual(t, len(tracker.connections), 3)
require.Greater(t, len(tracker.connections), 0)
require.Contains(t, tracker.connections,
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)},
"newest ICMP flow must be admitted after eviction")
require.NotContains(t, tracker.connections,
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)},
"oldest ICMP flow should have been evicted")
}

View File

@@ -1,16 +1,63 @@
package conntrack package conntrack
import ( import (
"fmt" "net"
"net/netip" "net/netip"
"os"
"strconv"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
// evictSampleSize bounds how many map entries we scan per eviction call.
// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU
// heuristic is good enough for a conntrack table that only overflows under
// abuse.
const evictSampleSize = 8
// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to
// def on empty or invalid; logs a warning on invalid.
func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration {
v := os.Getenv(name)
if v == "" {
return def
}
d, err := time.ParseDuration(v)
if err != nil {
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
return def
}
if d <= 0 {
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
return def
}
return d
}
// envInt parses an os.Getenv(name) as an int. Falls back to def on empty,
// invalid, or non-positive. Logs a warning on invalid input.
func envInt(logger *nblog.Logger, name string, def int) int {
v := os.Getenv(name)
if v == "" {
return def
}
n, err := strconv.Atoi(v)
switch {
case err != nil:
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
return def
case n <= 0:
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
return def
}
return n
}
// BaseConnTrack provides common fields and locking for all connection types // BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct { type BaseConnTrack struct {
FlowId uuid.UUID FlowId uuid.UUID
@@ -64,5 +111,7 @@ type ConnKey struct {
} }
func (c ConnKey) String() string { func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) return net.JoinHostPort(c.SrcIP.Unmap().String(), strconv.Itoa(int(c.SrcPort))) +
" → " +
net.JoinHostPort(c.DstIP.Unmap().String(), strconv.Itoa(int(c.DstPort)))
} }

View File

@@ -13,6 +13,54 @@ import (
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestConnKey_String(t *testing.T) {
tests := []struct {
name string
key ConnKey
expect string
}{
{
name: "IPv4",
key: ConnKey{
SrcIP: netip.MustParseAddr("192.168.1.1"),
DstIP: netip.MustParseAddr("10.0.0.1"),
SrcPort: 12345,
DstPort: 80,
},
expect: "192.168.1.1:12345 → 10.0.0.1:80",
},
{
name: "IPv6",
key: ConnKey{
SrcIP: netip.MustParseAddr("2001:db8::1"),
DstIP: netip.MustParseAddr("2001:db8::2"),
SrcPort: 54321,
DstPort: 443,
},
expect: "[2001:db8::1]:54321 → [2001:db8::2]:443",
},
{
name: "IPv4-mapped IPv6 unmaps",
key: ConnKey{
SrcIP: netip.MustParseAddr("::ffff:10.0.0.1"),
DstIP: netip.MustParseAddr("::ffff:10.0.0.2"),
SrcPort: 1000,
DstPort: 2000,
},
expect: "10.0.0.1:1000 → 10.0.0.2:2000",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := tc.key.String()
if got != tc.expect {
t.Errorf("got %q, want %q", got, tc.expect)
}
})
}
}
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) {

View File

@@ -0,0 +1,11 @@
//go:build !ios && !android
package conntrack
// Default per-tracker entry caps on desktop/server platforms. These mirror
// typical Linux netfilter nf_conntrack_max territory with ample headroom.
const (
DefaultMaxTCPEntries = 65536
DefaultMaxUDPEntries = 16384
DefaultMaxICMPEntries = 2048
)

View File

@@ -0,0 +1,13 @@
//go:build ios || android
package conntrack
// Default per-tracker entry caps on mobile platforms. iOS network extensions
// are capped at ~50 MB; Android runs under aggressive memory pressure. These
// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack
// is ~200 B plus map overhead).
const (
DefaultMaxTCPEntries = 4096
DefaultMaxUDPEntries = 2048
DefaultMaxICMPEntries = 512
)

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strconv"
"sync" "sync"
"time" "time"
@@ -21,9 +22,14 @@ const (
// ICMPCleanupInterval is how often we check for stale ICMP connections // ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info, // MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info.
// which includes the IP header (20 bytes) and transport header (8 bytes) // IPv4: 20-byte header + 8-byte transport = 28 bytes.
MaxICMPPayloadLength = 28 // IPv6: 40-byte header + 8-byte transport = 48 bytes.
MaxICMPPayloadLength = 48
// minICMPPayloadIPv4 is the minimum embedded packet length for IPv4 ICMP errors.
minICMPPayloadIPv4 = 28
// minICMPPayloadIPv6 is the minimum embedded packet length for IPv6 ICMP errors.
minICMPPayloadIPv6 = 48
) )
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
@@ -44,6 +50,9 @@ type ICMPConnTrack struct {
ICMPCode uint8 ICMPCode uint8
} }
// EnvICMPMaxEntries caps the ICMP conntrack table size.
const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX"
// ICMPTracker manages ICMP connection states // ICMPTracker manages ICMP connection states
type ICMPTracker struct { type ICMPTracker struct {
logger *nblog.Logger logger *nblog.Logger
@@ -52,6 +61,7 @@ type ICMPTracker struct {
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
maxEntries int
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
@@ -65,7 +75,7 @@ type ICMPInfo struct {
// String implements fmt.Stringer for lazy evaluation in log messages // String implements fmt.Stringer for lazy evaluation in log messages
func (info ICMPInfo) String() string { func (info ICMPInfo) String() string {
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength { if info.isErrorMessage() && info.PayloadLen >= minICMPPayloadIPv4 {
if origInfo := info.parseOriginalPacket(); origInfo != "" { if origInfo := info.parseOriginalPacket(); origInfo != "" {
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo) return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
} }
@@ -74,42 +84,72 @@ func (info ICMPInfo) String() string {
return info.TypeCode.String() return info.TypeCode.String()
} }
// isErrorMessage returns true if this ICMP type carries original packet info // isErrorMessage returns true if this ICMP type carries original packet info.
// Covers both ICMPv4 and ICMPv6 error types. Without a family field we match
// both sets; type 3 overlaps (v4 DestUnreachable / v6 TimeExceeded) so it's
// kept as a literal.
func (info ICMPInfo) isErrorMessage() bool { func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type() typ := info.TypeCode.Type()
return typ == 3 || // Destination Unreachable // ICMPv4 error types
typ == 5 || // Redirect if typ == layers.ICMPv4TypeDestinationUnreachable ||
typ == 11 || // Time Exceeded typ == layers.ICMPv4TypeRedirect ||
typ == 12 // Parameter Problem typ == layers.ICMPv4TypeTimeExceeded ||
typ == layers.ICMPv4TypeParameterProblem {
return true
}
// ICMPv6 error types (type 3 already matched above as v4 DestUnreachable)
if typ == layers.ICMPv6TypeDestinationUnreachable ||
typ == layers.ICMPv6TypePacketTooBig ||
typ == layers.ICMPv6TypeParameterProblem {
return true
}
return false
} }
// parseOriginalPacket extracts info about the original packet from ICMP payload // parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string { func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen < MaxICMPPayloadLength { if info.PayloadLen == 0 {
return "" return ""
} }
// TODO: handle IPv6 version := (info.PayloadData[0] >> 4) & 0xF
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
var protocol uint8
var srcIP, dstIP net.IP
var transportData []byte
switch version {
case 4:
if info.PayloadLen < minICMPPayloadIPv4 {
return ""
}
protocol = info.PayloadData[9]
srcIP = net.IP(info.PayloadData[12:16])
dstIP = net.IP(info.PayloadData[16:20])
transportData = info.PayloadData[20:]
case 6:
if info.PayloadLen < minICMPPayloadIPv6 {
return ""
}
// Next Header field in IPv6 header
protocol = info.PayloadData[6]
srcIP = net.IP(info.PayloadData[8:24])
dstIP = net.IP(info.PayloadData[24:40])
transportData = info.PayloadData[40:]
default:
return "" return ""
} }
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) { switch nftypes.Protocol(protocol) {
case nftypes.TCP: case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1]) srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3]) dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort) return "TCP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
case nftypes.UDP: case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1]) srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3]) dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort) return "UDP " + net.JoinHostPort(srcIP.String(), strconv.Itoa(int(srcPort))) + " → " + net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort)))
case nftypes.ICMP: case nftypes.ICMP:
icmpType := transportData[0] icmpType := transportData[0]
@@ -135,6 +175,7 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval), cleanupTicker: time.NewTicker(ICMPCleanupInterval),
tickerCancel: cancel, tickerCancel: cancel,
maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries),
flowLogger: flowLogger, flowLogger: flowLogger,
} }
@@ -221,7 +262,9 @@ func (t *ICMPTracker) track(
// non echo requests don't need tracking // non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) { if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
}
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size) t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return return
} }
@@ -240,16 +283,22 @@ func (t *ICMPTracker) track(
conn.UpdateCounters(direction, size) conn.UpdateCounters(direction, size)
t.mutex.Lock() t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
}
t.sendEvent(nftypes.TypeStart, conn, ruleId) t.sendEvent(nftypes.TypeStart, conn, ruleId)
} }
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request.
// Accepts both ICMPv4 (type 0) and ICMPv6 (type 129) echo replies.
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool { func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) { if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) {
return false return false
} }
@@ -286,6 +335,34 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
} }
} }
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded sample scan: picks the oldest among up to evictSampleSize entries.
func (t *ICMPTracker) evictOneLocked() {
var candKey ICMPConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
delete(t.connections, candKey)
}
}
func (t *ICMPTracker) cleanup() { func (t *ICMPTracker) cleanup() {
t.mutex.Lock() t.mutex.Lock()
defer t.mutex.Unlock() defer t.mutex.Unlock()
@@ -294,13 +371,22 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", if t.logger.Enabled(nblog.LevelTrace) {
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
func icmpProtocolForAddr(ip netip.Addr) nftypes.Protocol {
if ip.Is6() {
return nftypes.ICMPv6
}
return nftypes.ICMP
}
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() { func (t *ICMPTracker) Close() {
t.tickerCancel() t.tickerCancel()
@@ -316,7 +402,7 @@ func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []
Type: typ, Type: typ,
RuleID: ruleID, RuleID: ruleID,
Direction: conn.Direction, Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 Protocol: icmpProtocolForAddr(conn.SourceIP),
SourceIP: conn.SourceIP, SourceIP: conn.SourceIP,
DestIP: conn.DestIP, DestIP: conn.DestIP,
ICMPType: conn.ICMPType, ICMPType: conn.ICMPType,
@@ -334,7 +420,7 @@ func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Ad
Type: nftypes.TypeStart, Type: nftypes.TypeStart,
RuleID: ruleID, RuleID: ruleID,
Direction: direction, Direction: direction,
Protocol: nftypes.ICMP, Protocol: icmpProtocolForAddr(srcIP),
SourceIP: srcIP, SourceIP: srcIP,
DestIP: dstIP, DestIP: dstIP,
ICMPType: typ, ICMPType: typ,

View File

@@ -5,6 +5,42 @@ import (
"testing" "testing"
) )
func TestICMPConnKey_String(t *testing.T) {
tests := []struct {
name string
key ICMPConnKey
expect string
}{
{
name: "IPv4",
key: ICMPConnKey{
SrcIP: netip.MustParseAddr("192.168.1.1"),
DstIP: netip.MustParseAddr("10.0.0.1"),
ID: 1234,
},
expect: "192.168.1.1 → 10.0.0.1 (id 1234)",
},
{
name: "IPv6",
key: ICMPConnKey{
SrcIP: netip.MustParseAddr("2001:db8::1"),
DstIP: netip.MustParseAddr("2001:db8::2"),
ID: 5678,
},
expect: "2001:db8::1 → 2001:db8::2 (id 5678)",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := tc.key.String()
if got != tc.expect {
t.Errorf("got %q, want %q", got, tc.expect)
}
})
}
}
func BenchmarkICMPTracker(b *testing.B) { func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)

View File

@@ -38,6 +38,27 @@ const (
TCPHandshakeTimeout = 60 * time.Second TCPHandshakeTimeout = 60 * time.Second
// TCPCleanupInterval is how often we check for stale connections // TCPCleanupInterval is how often we check for stale connections
TCPCleanupInterval = 5 * time.Minute TCPCleanupInterval = 5 * time.Minute
// FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states.
// Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait.
FinWaitTimeout = 60 * time.Second
// CloseWaitTimeout bounds CLOSE_WAIT. Matches Linux default; apps
// holding CloseWait longer than this should bump the env var.
CloseWaitTimeout = 60 * time.Second
// LastAckTimeout bounds LAST_ACK. Matches Linux default.
LastAckTimeout = 30 * time.Second
)
// Env vars to override per-state teardown timeouts. Values parsed by
// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the
// defaults above with a warning.
const (
EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT"
EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT"
EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT"
// EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries
// (tombstones first) are evicted when the cap is reached.
EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX"
) )
// TCPState represents the state of a TCP connection // TCPState represents the state of a TCP connection
@@ -133,14 +154,18 @@ func (t *TCPConnTrack) SetTombstone() {
// TCPTracker manages TCP connection states // TCPTracker manages TCP connection states
type TCPTracker struct { type TCPTracker struct {
logger *nblog.Logger logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex mutex sync.RWMutex
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc tickerCancel context.CancelFunc
timeout time.Duration timeout time.Duration
waitTimeout time.Duration waitTimeout time.Duration
flowLogger nftypes.FlowLogger finWaitTimeout time.Duration
closeWaitTimeout time.Duration
lastAckTimeout time.Duration
maxEntries int
flowLogger nftypes.FlowLogger
} }
// NewTCPTracker creates a new TCP connection tracker // NewTCPTracker creates a new TCP connection tracker
@@ -155,13 +180,17 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{ tracker := &TCPTracker{
logger: logger, logger: logger,
connections: make(map[ConnKey]*TCPConnTrack), connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel, tickerCancel: cancel,
timeout: timeout, timeout: timeout,
waitTimeout: waitTimeout, waitTimeout: waitTimeout,
flowLogger: flowLogger, finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout),
closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout),
lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout),
maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries),
flowLogger: flowLogger,
} }
go tracker.cleanupRoutine(ctx) go tracker.cleanupRoutine(ctx)
@@ -209,6 +238,12 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
if exists || flags&TCPSyn == 0 { if exists || flags&TCPSyn == 0 {
return return
} }
// Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't
// create spurious conntrack entries. Not mandated by RFC 9293 but a
// common hardening (Linux netfilter/nftables rejects these too).
if !isValidFlagCombination(flags) {
return
}
conn := &TCPConnTrack{ conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{ BaseConnTrack: BaseConnTrack{
@@ -225,20 +260,65 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.state.Store(int32(TCPStateNew)) conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort)) conn.DNATOrigPort.Store(uint32(origPort))
if origPort != 0 { if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) if origPort != 0 {
} else { t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
t.logger.Trace2("New %s TCP connection: %s", direction, key) } else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
} }
t.updateState(key, conn, flags, direction, size) t.updateState(key, conn, flags, direction, size)
t.mutex.Lock() t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn, ruleID) t.sendEvent(nftypes.TypeStart, conn, ruleID)
} }
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map
// iteration order is randomized), preferring a tombstone. If no tombstone
// found in the sample, evicts the oldest among the sampled entries. O(1)
// worst case — cheap enough to run on every insert at cap during abuse.
func (t *TCPTracker) evictOneLocked() {
var candKey ConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
if c.IsTombstone() {
delete(t.connections, k)
return
}
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
// TypeEnd is already emitted at the state transition to
// TimeWait and when a connection is tombstoned. Only emit
// here when we're reaping a still-active flow.
if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
}
delete(t.connections, candKey)
}
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection // IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool { func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{ key := ConnKey{
@@ -256,12 +336,19 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
return false return false
} }
// Reject illegal flag combinations regardless of state. These never belong
// to a legitimate flow and must not advance or tear down state.
if !isValidFlagCombination(flags) {
if t.logger.Enabled(nblog.LevelWarn) {
t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState())
}
return false
}
currentState := conn.GetState() currentState := conn.GetState()
if !t.isValidStateForFlags(currentState, flags) { if !t.isValidStateForFlags(currentState, flags) {
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key) if t.logger.Enabled(nblog.LevelWarn) {
// allow all flags for established for now t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
if currentState == TCPStateEstablished {
return true
} }
return false return false
} }
@@ -270,116 +357,208 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
return true return true
} }
// updateState updates the TCP connection state based on flags // updateState updates the TCP connection state based on flags.
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) { func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
conn.UpdateLastSeen()
conn.UpdateCounters(packetDir, size) conn.UpdateCounters(packetDir, size)
// Malformed flag combinations must not refresh lastSeen or drive state,
// otherwise spoofed packets keep a dead flow alive past its timeout.
if !isValidFlagCombination(flags) {
return
}
conn.UpdateLastSeen()
currentState := conn.GetState() currentState := conn.GetState()
if flags&TCPRst != 0 { if flags&TCPRst != 0 {
if conn.CompareAndSwapState(currentState, TCPStateClosed) { // Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we
conn.SetTombstone() // cannot apply the RFC 5961 in-window RST check, so we conservatively
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]", // reject RSTs that the spec would accept (TIME-WAIT with in-window
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) // SEQ, SynSent from same direction as own SYN, etc.).
t.sendEvent(nftypes.TypeEnd, conn, nil) t.handleRst(key, conn, currentState, packetDir)
}
return return
} }
var newState TCPState newState := nextState(currentState, conn.Direction, packetDir, flags)
switch currentState { if newState == 0 || !conn.CompareAndSwapState(currentState, newState) {
case TCPStateNew: return
if flags&TCPSyn != 0 && flags&TCPAck == 0 { }
if conn.Direction == nftypes.Egress { t.onTransition(key, conn, currentState, newState, packetDir)
newState = TCPStateSynSent }
} else {
newState = TCPStateSynReceived
}
}
case TCPStateSynSent: // handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs
if flags&TCPSyn != 0 && flags&TCPAck != 0 { // from the SYN direction are ignored; otherwise the flow is tombstoned.
if packetDir != conn.Direction { func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) {
newState = TCPStateEstablished // TimeWait exists to absorb late segments; don't let a late RST
} else { // tombstone the entry and break same-4-tuple reuse.
// Simultaneous open if currentState == TCPStateTimeWait {
newState = TCPStateSynReceived return
} }
} // A RST from the same direction as the SYN cannot be a legitimate
// response and must not tear down a half-open connection.
if currentState == TCPStateSynSent && packetDir == conn.Direction {
return
}
if !conn.CompareAndSwapState(currentState, TCPStateClosed) {
return
}
conn.SetTombstone()
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateSynReceived: // stateTransition describes one state's transition logic. It receives the
if flags&TCPAck != 0 && flags&TCPSyn == 0 { // packet's flags plus whether the packet direction matches the connection's
if packetDir == conn.Direction { // origin direction (same=true means same side as the SYN initiator). Return 0
newState = TCPStateEstablished // for no transition.
} type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState
}
case TCPStateEstablished: // stateTable maps each state to its transition function. Centralized here so
if flags&TCPFin != 0 { // nextState stays trivial and each rule is easy to read in isolation.
if packetDir == conn.Direction { var stateTable = map[TCPState]stateTransition{
newState = TCPStateFinWait1 TCPStateNew: transNew,
} else { TCPStateSynSent: transSynSent,
newState = TCPStateCloseWait TCPStateSynReceived: transSynReceived,
} TCPStateEstablished: transEstablished,
} TCPStateFinWait1: transFinWait1,
TCPStateFinWait2: transFinWait2,
TCPStateClosing: transClosing,
TCPStateCloseWait: transCloseWait,
TCPStateLastAck: transLastAck,
}
case TCPStateFinWait1: // nextState returns the target TCP state for the given current state and
if packetDir != conn.Direction { // packet, or 0 if the packet does not trigger a transition.
switch { func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState {
case flags&TCPFin != 0 && flags&TCPAck != 0: fn, ok := stateTable[currentState]
newState = TCPStateClosing if !ok {
case flags&TCPFin != 0: return 0
newState = TCPStateClosing }
case flags&TCPAck != 0: return fn(flags, connDir, packetDir == connDir)
newState = TCPStateFinWait2 }
}
}
case TCPStateFinWait2: func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState {
if flags&TCPFin != 0 { if flags&TCPSyn != 0 && flags&TCPAck == 0 {
newState = TCPStateTimeWait if connDir == nftypes.Egress {
return TCPStateSynSent
} }
return TCPStateSynReceived
}
return 0
}
case TCPStateClosing: func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 { if flags&TCPSyn != 0 && flags&TCPAck != 0 {
newState = TCPStateTimeWait if same {
return TCPStateSynReceived // simultaneous open
} }
return TCPStateEstablished
}
return 0
}
case TCPStateCloseWait: func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin != 0 { if flags&TCPAck != 0 && flags&TCPSyn == 0 && same {
newState = TCPStateLastAck return TCPStateEstablished
} }
return 0
}
case TCPStateLastAck: func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 { if flags&TCPFin == 0 {
newState = TCPStateClosed return 0
} }
if same {
return TCPStateFinWait1
}
return TCPStateCloseWait
}
// transFinWait1 handles the active-close peer response. A FIN carrying our
// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1:
// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves
// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2.
func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState {
if same {
return 0
}
if flags&TCPFin != 0 && flags&TCPAck != 0 {
return TCPStateTimeWait
}
switch {
case flags&TCPFin != 0:
return TCPStateClosing
case flags&TCPAck != 0:
return TCPStateFinWait2
}
return 0
}
// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances.
func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin != 0 && !same {
return TCPStateTimeWait
}
return 0
}
// transClosing completes a simultaneous close on the peer's ACK.
func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && !same {
return TCPStateTimeWait
}
return 0
}
// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits.
func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin != 0 && same {
return TCPStateLastAck
}
return 0
}
// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits).
func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && !same {
return TCPStateClosed
}
return 0
}
// onTransition handles logging and flow-event emission after a successful
// state transition. TimeWait and Closed are terminal for flow accounting.
func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) {
traceOn := t.logger.Enabled(nblog.LevelTrace)
if traceOn {
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir)
} }
if newState != 0 && conn.CompareAndSwapState(currentState, newState) { switch to {
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir) case TCPStateTimeWait:
if traceOn {
switch newState {
case TCPStateTimeWait:
t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]", t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil) }
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed: case TCPStateClosed:
conn.SetTombstone() conn.SetTombstone()
if traceOn {
t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]", t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
// isValidStateForFlags checks if the TCP flags are valid for the current connection state // isValidStateForFlags checks if the TCP flags are valid for the current
// connection state. Caller must have already verified the flag combination is
// legal via isValidFlagCombination.
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
if flags&TCPRst != 0 { if flags&TCPRst != 0 {
if state == TCPStateSynSent { if state == TCPStateSynSent {
return flags&TCPAck != 0 return flags&TCPAck != 0
@@ -449,15 +628,24 @@ func (t *TCPTracker) cleanup() {
timeout = t.waitTimeout timeout = t.waitTimeout
case TCPStateEstablished: case TCPStateEstablished:
timeout = t.timeout timeout = t.timeout
case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing:
timeout = t.finWaitTimeout
case TCPStateCloseWait:
timeout = t.closeWaitTimeout
case TCPStateLastAck:
timeout = t.lastAckTimeout
default: default:
// SynSent / SynReceived / New
timeout = TCPHandshakeTimeout timeout = TCPHandshakeTimeout
} }
if conn.timeoutExceeded(timeout) { if conn.timeoutExceeded(timeout) {
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]", if t.logger.Enabled(nblog.LevelTrace) {
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
// event already handled by state change // event already handled by state change
if currentState != TCPStateTimeWait { if currentState != TCPStateTimeWait {

View File

@@ -0,0 +1,100 @@
package conntrack
import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
)
// RST hygiene tests: the tracker currently closes the flow on any RST that
// matches the 4-tuple, regardless of direction or state. These tests cover
// the minimum checks we want (no SEQ tracking).
func TestTCPRstInSynSentWrongDirection(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateSynSent, conn.GetState())
// A RST arriving in the same direction as the SYN (i.e. TrackOutbound)
// cannot be a legitimate response. It must not close the connection.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0)
require.Equal(t, TCPStateSynSent, conn.GetState(),
"RST in same direction as SYN must not close connection")
require.False(t, conn.IsTombstone())
}
func TestTCPRstInTimeWaitIgnored(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Drive to TIME-WAIT via active close.
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned")
// Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT
// exists to absorb late segments).
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.Equal(t, TCPStateTimeWait, conn.GetState(),
"RST in TIME-WAIT must not transition state")
require.False(t, conn.IsTombstone(),
"RST in TIME-WAIT must not tombstone the entry")
}
func TestTCPIllegalFlagCombos(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
conn := tracker.connections[key]
// Illegal combos must be rejected and must not change state.
combos := []struct {
name string
flags uint8
}{
{"SYN+RST", TCPSyn | TCPRst},
{"FIN+RST", TCPFin | TCPRst},
{"SYN+FIN", TCPSyn | TCPFin},
{"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst},
}
for _, c := range combos {
t.Run(c.name, func(t *testing.T) {
before := conn.GetState()
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0)
require.False(t, valid, "illegal flag combo must be rejected: %s", c.name)
require.Equal(t, before, conn.GetState(),
"illegal flag combo must not change state")
require.False(t, conn.IsTombstone())
})
}
}

View File

@@ -0,0 +1,235 @@
package conntrack
import (
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// These tests exercise cases where the TCP state machine currently advances
// on retransmitted or wrong-direction segments and tears the flow down
// prematurely. They are expected to fail until the direction checks are added.
func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Peer sends FIN -> CloseWait (our app has not yet closed).
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Peer retransmits their FIN (ACK may have been delayed). We have NOT
// sent our FIN yet, so state must remain CloseWait.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "retransmitted peer FIN must still be accepted")
require.Equal(t, TCPStateCloseWait, conn.GetState(),
"retransmitted peer FIN must not advance CloseWait to LastAck")
// Our app finally closes -> LastAck.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Peer ACK closes.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateClosed, conn.GetState())
}
func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// We initiate close.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Stray retransmit of our own FIN (same direction as originator) must
// NOT advance FinWait2 to TimeWait; only the peer's FIN should.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait2, conn.GetState(),
"own FIN retransmit must not advance FinWait2 to TimeWait")
// Peer FIN -> TimeWait.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
}
func TestTCPLastAckDirectionCheck(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateLastAck, conn.GetState())
// Our own ACK retransmit (same direction as originator) must NOT close.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState(),
"own ACK retransmit in LastAck must not transition to Closed")
// Peer's ACK -> Closed.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
require.Equal(t, TCPStateClosed, conn.GetState())
}
func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Our own ACK retransmit (same direction as originator) must not advance.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState(),
"own ACK in FinWait1 must not advance to FinWait2")
}
func TestTCPPerStateTeardownTimeouts(t *testing.T) {
// Verify cleanup reaps entries in each teardown state at the configured
// per-state timeout, not at the single handshake timeout.
t.Setenv(EnvTCPFinWaitTimeout, "50ms")
t.Setenv(EnvTCPCloseWaitTimeout, "80ms")
t.Setenv(EnvTCPLastAckTimeout, "30ms")
dstIP := netip.MustParseAddr("100.64.0.2")
dstPort := uint16(80)
// Drives a connection to the target state, forces its lastSeen well
// beyond the configured timeout, runs cleanup, and asserts reaping.
cases := []struct {
name string
// drive takes a fresh tracker and returns the conn key after
// transitioning the flow into the intended teardown state.
drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState)
}{
{
name: "FinWait1",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1
},
},
{
name: "FinWait2",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2
},
},
{
name: "CloseWait",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait
},
},
{
name: "LastAck",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck
},
},
}
// Use a unique source port per subtest so nothing aliases.
port := uint16(12345)
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout)
require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout)
require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout)
srcIP := netip.MustParseAddr("100.64.0.1")
port++
key, wantState := c.drive(t, tracker, srcIP, port)
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, wantState, conn.GetState())
// Age the entry past the largest per-state timeout.
conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano())
tracker.cleanup()
_, exists := tracker.connections[key]
require.False(t, exists, "%s entry should be reaped", c.name)
})
}
}
func TestTCPEstablishedPSHACKInFinStates(t *testing.T) {
// Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN
// teardown states, which some stacks emit during close.
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Peer FIN -> CloseWait.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
// Peer pushes trailing data + FIN|PSH|ACK (legal).
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100),
"FIN|PSH|ACK in CloseWait must be accepted")
// Bare ACK keepalive from peer in CloseWait must be accepted.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0),
"bare ACK in CloseWait must be accepted")
}

View File

@@ -17,6 +17,9 @@ const (
DefaultUDPTimeout = 30 * time.Second DefaultUDPTimeout = 30 * time.Second
// UDPCleanupInterval is how often we check for stale connections // UDPCleanupInterval is how often we check for stale connections
UDPCleanupInterval = 15 * time.Second UDPCleanupInterval = 15 * time.Second
// EnvUDPMaxEntries caps the UDP conntrack table size.
EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX"
) )
// UDPConnTrack represents a UDP connection state // UDPConnTrack represents a UDP connection state
@@ -34,6 +37,7 @@ type UDPTracker struct {
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
maxEntries int
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
@@ -51,6 +55,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval), cleanupTicker: time.NewTicker(UDPCleanupInterval),
tickerCancel: cancel, tickerCancel: cancel,
maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries),
flowLogger: flowLogger, flowLogger: flowLogger,
} }
@@ -117,13 +122,18 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
conn.UpdateCounters(direction, size) conn.UpdateCounters(direction, size)
t.mutex.Lock() t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
if origPort != 0 { if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) if origPort != 0 {
} else { t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
t.logger.Trace2("New %s UDP connection: %s", direction, key) } else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
} }
t.sendEvent(nftypes.TypeStart, conn, ruleID) t.sendEvent(nftypes.TypeStart, conn, ruleID)
} }
@@ -151,6 +161,34 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
return true return true
} }
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded sample: picks the oldest among up to evictSampleSize entries.
func (t *UDPTracker) evictOneLocked() {
var candKey ConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
delete(t.connections, candKey)
}
}
// cleanupRoutine periodically removes stale connections // cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine(ctx context.Context) { func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop() defer t.cleanupTicker.Stop()
@@ -173,8 +211,10 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]", if t.logger.Enabled(nblog.LevelTrace) {
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }

View File

@@ -18,9 +18,10 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
@@ -35,8 +36,10 @@ import (
const ( const (
layerTypeAll = 255 layerTypeAll = 255
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation // ipv4TCPHeaderMinSize represents minimum IPv4 (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40 ipv4TCPHeaderMinSize = 40
// ipv6TCPHeaderMinSize represents minimum IPv6 (40) + TCP (20) header size for MSS calculation
ipv6TCPHeaderMinSize = 60
) )
// serviceKey represents a protocol/port combination for netstack service registry // serviceKey represents a protocol/port combination for netstack service registry
@@ -123,7 +126,7 @@ type Manager struct {
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
blockRule firewall.Rule blockRules []firewall.Rule
// Internal 1:1 DNAT // Internal 1:1 DNAT
dnatEnabled atomic.Bool dnatEnabled atomic.Bool
@@ -138,9 +141,10 @@ type Manager struct {
netstackServices map[serviceKey]struct{} netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex netstackServiceMutex sync.RWMutex
mtu uint16 mtu uint16
mssClampValue uint16 mssClampValueIPv4 uint16
mssClampEnabled bool mssClampValueIPv6 uint16
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only. // Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[common.PacketHook] udpHookOut atomic.Pointer[common.PacketHook]
@@ -157,11 +161,28 @@ type decoder struct {
icmp4 layers.ICMPv4 icmp4 layers.ICMPv4
icmp6 layers.ICMPv6 icmp6 layers.ICMPv6
decoded []gopacket.LayerType decoded []gopacket.LayerType
parser *gopacket.DecodingLayerParser parser4 *gopacket.DecodingLayerParser
parser6 *gopacket.DecodingLayerParser
dnatOrigPort uint16 dnatOrigPort uint16
} }
// decodePacket decodes packet data using the appropriate parser based on IP version.
func (d *decoder) decodePacket(data []byte) error {
if len(data) == 0 {
return errors.New("empty packet")
}
version := data[0] >> 4
switch version {
case 4:
return d.parser4.DecodeLayers(data, &d.decoded)
case 6:
return d.parser6.DecodeLayers(data, &d.decoded)
default:
return fmt.Errorf("unknown IP version %d", version)
}
}
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger, mtu) return create(iface, nil, disableServerRoutes, flowLogger, mtu)
@@ -219,11 +240,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
d := &decoder{ d := &decoder{
decoded: []gopacket.LayerType{}, decoded: []gopacket.LayerType{},
} }
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
return d return d
}, },
}, },
@@ -249,7 +276,12 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
if !disableMSSClamping { if !disableMSSClamping {
m.mssClampEnabled = true m.mssClampEnabled = true
m.mssClampValue = mtu - ipTCPHeaderMinSize if mtu > ipv4TCPHeaderMinSize {
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
}
if mtu > ipv6TCPHeaderMinSize {
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
}
} }
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err) return nil, fmt.Errorf("update local IPs: %w", err)
@@ -272,13 +304,25 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
return m, nil return m, nil
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) { // blockInvalidRouted installs drop rules for traffic to the wg overlay that
// arrives via the routing path. v4 and v6 are independent: a v6 install
// failure leaves v4 protection in place (and vice versa) so the returned
// slice always contains whatever was successfully installed, even on error.
// Callers must persist the slice so DisableRouting can clean partial state.
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule, error) {
wgPrefix := iface.Address().Network wgPrefix := iface.Address().Network
log.Debugf("blocking invalid routed traffic for %s", wgPrefix) log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
rule, err := m.addRouteFiltering( sources := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
v6Net := iface.Address().IPv6Net
if v6Net.IsValid() {
sources = append(sources, netip.PrefixFrom(netip.IPv6Unspecified(), 0))
}
var rules []firewall.Rule
v4Rule, err := m.addRouteFiltering(
nil, nil,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, sources,
firewall.Network{Prefix: wgPrefix}, firewall.Network{Prefix: wgPrefix},
firewall.ProtocolALL, firewall.ProtocolALL,
nil, nil,
@@ -286,12 +330,30 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, e
firewall.ActionDrop, firewall.ActionDrop,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("block wg nte : %w", err) return rules, fmt.Errorf("block wg v4 net: %w", err)
}
rules = append(rules, v4Rule)
if v6Net.IsValid() {
log.Debugf("blocking invalid routed traffic for %s", v6Net)
v6Rule, err := m.addRouteFiltering(
nil,
sources,
firewall.Network{Prefix: v6Net},
firewall.ProtocolALL,
nil,
nil,
firewall.ActionDrop,
)
if err != nil {
return rules, fmt.Errorf("block wg v6 net: %w", err)
}
rules = append(rules, v6Rule)
} }
// TODO: Block networks that we're a client of // TODO: Block networks that we're a client of
return rule, nil return rules, nil
} }
func (m *Manager) determineRouting() error { func (m *Manager) determineRouting() error {
@@ -521,7 +583,7 @@ func (m *Manager) addRouteFiltering(
mgmtId: id, mgmtId: id,
sources: sources, sources: sources,
dstSet: destination.Set, dstSet: destination.Set,
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4), protoLayer: protoToLayer(proto, ipLayerFromPrefix(destination.Prefix)),
srcPort: sPort, srcPort: sPort,
dstPort: dPort, dstPort: dPort,
action: action, action: action,
@@ -612,10 +674,10 @@ func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers. // resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held. // Must be called with m.mutex held.
func (m *Manager) resetState() { func (m *Manager) resetState() {
maps.Clear(m.outgoingRules) clear(m.outgoingRules)
maps.Clear(m.incomingDenyRules) clear(m.incomingDenyRules)
maps.Clear(m.incomingRules) clear(m.incomingRules)
maps.Clear(m.routeRulesMap) clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0] m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil) m.udpHookOut.Store(nil)
m.tcpHookOut.Store(nil) m.tcpHookOut.Store(nil)
@@ -676,11 +738,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
} }
destinations := matches[0].destinations destinations := matches[0].destinations
for _, prefix := range prefixes { destinations = append(destinations, prefixes...)
if prefix.Addr().Is4() {
destinations = append(destinations, prefix)
}
}
slices.SortFunc(destinations, func(a, b netip.Prefix) int { slices.SortFunc(destinations, func(a, b netip.Prefix) int {
cmp := a.Addr().Compare(b.Addr()) cmp := a.Addr().Compare(b.Addr())
@@ -719,7 +777,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
return false return false
} }
@@ -729,7 +787,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
srcIP, dstIP := m.extractIPs(d) srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() { if !srcIP.IsValid() {
m.logger.Error1("Unknown network layer: %v", d.decoded[0]) if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
}
return false return false
} }
@@ -803,12 +863,32 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false return false
} }
var mssClampValue uint16
var ipHeaderSize int
switch d.decoded[0] {
case layers.LayerTypeIPv4:
mssClampValue = m.mssClampValueIPv4
ipHeaderSize = int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
return false
}
case layers.LayerTypeIPv6:
mssClampValue = m.mssClampValueIPv6
ipHeaderSize = 40
default:
return false
}
if mssClampValue == 0 {
return false
}
mssOptionIndex := -1 mssOptionIndex := -1
var currentMSS uint16 var currentMSS uint16
for i, opt := range d.tcp.Options { for i, opt := range d.tcp.Options {
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 { if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
currentMSS = binary.BigEndian.Uint16(opt.OptionData) currentMSS = binary.BigEndian.Uint16(opt.OptionData)
if currentMSS > m.mssClampValue { if currentMSS > mssClampValue {
mssOptionIndex = i mssOptionIndex = i
break break
} }
@@ -819,20 +899,17 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false return false
} }
ipHeaderSize := int(d.ip4.IHL) * 4 if !m.updateMSSOption(packetData, d, mssOptionIndex, mssClampValue, ipHeaderSize) {
if ipHeaderSize < 20 {
return false return false
} }
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) { if m.logger.Enabled(nblog.LevelTrace) {
return false m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
} }
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
return true return true
} }
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool { func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex int, mssClampValue uint16, ipHeaderSize int) bool {
tcpHeaderStart := ipHeaderSize tcpHeaderStart := ipHeaderSize
tcpOptionsStart := tcpHeaderStart + 20 tcpOptionsStart := tcpHeaderStart + 20
@@ -847,7 +924,7 @@ func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex,
} }
mssValueOffset := optOffset + 2 mssValueOffset := optOffset + 2
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue) binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], mssClampValue)
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart) m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
return true return true
@@ -857,18 +934,32 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
tcpLayer := packetData[tcpHeaderStart:] tcpLayer := packetData[tcpHeaderStart:]
tcpLength := len(packetData) - tcpHeaderStart tcpLength := len(packetData) - tcpHeaderStart
// Zero out existing checksum
tcpLayer[16] = 0 tcpLayer[16] = 0
tcpLayer[17] = 0 tcpLayer[17] = 0
// Build pseudo-header checksum based on IP version
var pseudoSum uint32 var pseudoSum uint32
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) switch d.decoded[0] {
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) case layers.LayerTypeIPv4:
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
pseudoSum += uint32(d.ip4.Protocol) pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
pseudoSum += uint32(tcpLength) pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
case layers.LayerTypeIPv6:
for i := 0; i < 16; i += 2 {
pseudoSum += uint32(d.ip6.SrcIP[i])<<8 | uint32(d.ip6.SrcIP[i+1])
}
for i := 0; i < 16; i += 2 {
pseudoSum += uint32(d.ip6.DstIP[i])<<8 | uint32(d.ip6.DstIP[i+1])
}
pseudoSum += uint32(tcpLength)
pseudoSum += uint32(layers.IPProtocolTCP)
}
var sum = pseudoSum sum := pseudoSum
for i := 0; i < tcpLength-1; i += 2 { for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
} }
@@ -906,6 +997,9 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData
} }
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
case layers.LayerTypeICMPv6:
id, tc := icmpv6EchoFields(d)
m.icmpTracker.TrackOutbound(srcIP, dstIP, id, tc, d.icmp6.Payload, size)
} }
} }
@@ -919,6 +1013,9 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
case layers.LayerTypeICMPv6:
id, tc := icmpv6EchoFields(d)
m.icmpTracker.TrackInbound(srcIP, dstIP, id, tc, ruleID, d.icmp6.Payload, size)
} }
d.dnatOrigPort = 0 d.dnatOrigPort = 0
@@ -951,15 +1048,21 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// TODO: pass fragments of routed packets to forwarder // TODO: pass fragments of routed packets to forwarder
if fragment { if fragment {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", if m.logger.Enabled(nblog.LevelTrace) {
srcIP, dstIP, d.ip4.Id, d.ip4.Flags) if d.decoded[0] == layers.LayerTypeIPv4 {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
} else {
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
}
}
return false return false
} }
// TODO: optimize port DNAT by caching matched rules in conntrack // TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information // Re-decode after port DNAT translation to update port information
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true return true
} }
@@ -968,7 +1071,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
if translated := m.translateInboundReverse(packetData, d); translated { if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses // Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
return true return true
} }
@@ -994,8 +1097,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
pnum := getProtocolFromPacket(d) pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", if m.logger.Enabled(nblog.LevelTrace) {
ruleID, pnum, srcIP, srcPort, dstIP, dstPort) m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
}
m.flowLogger.StoreEvent(nftypes.EventFields{ m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
@@ -1045,8 +1150,10 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
// Drop if routing is disabled // Drop if routing is disabled
if !m.routingEnabled.Load() { if !m.routingEnabled.Load() {
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s", if m.logger.Enabled(nblog.LevelTrace) {
srcIP, dstIP) m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
}
return true return true
} }
@@ -1063,8 +1170,10 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
if !pass { if !pass {
proto := getProtocolFromPacket(d) proto := getProtocolFromPacket(d)
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", if m.logger.Enabled(nblog.LevelTrace) {
ruleID, proto, srcIP, srcPort, dstIP, dstPort) m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
}
m.flowLogger.StoreEvent(nftypes.EventFields{ m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
@@ -1100,6 +1209,48 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return true return true
} }
// icmpv6EchoFields extracts the echo identifier from an ICMPv6 packet and maps
// the ICMPv6 type code to an ICMPv4TypeCode so the ICMP conntrack can handle
// both families uniformly. The echo ID is in the first two payload bytes.
func icmpv6EchoFields(d *decoder) (id uint16, tc layers.ICMPv4TypeCode) {
if len(d.icmp6.Payload) >= 2 {
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
}
// Map ICMPv6 echo types to ICMPv4 equivalents for unified tracking.
switch d.icmp6.TypeCode.Type() {
case layers.ICMPv6TypeEchoRequest:
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)
case layers.ICMPv6TypeEchoReply:
tc = layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoReply, 0)
default:
tc = layers.CreateICMPv4TypeCode(d.icmp6.TypeCode.Type(), d.icmp6.TypeCode.Code())
}
return id, tc
}
// protoLayerMatches checks if a packet's protocol layer matches a rule's expected
// protocol layer. ICMPv4 and ICMPv6 are treated as equivalent when matching
// ICMP rules since management sends a single ICMP rule for both families.
func protoLayerMatches(ruleLayer, packetLayer gopacket.LayerType) bool {
if ruleLayer == packetLayer {
return true
}
if ruleLayer == layers.LayerTypeICMPv4 && packetLayer == layers.LayerTypeICMPv6 {
return true
}
if ruleLayer == layers.LayerTypeICMPv6 && packetLayer == layers.LayerTypeICMPv4 {
return true
}
return false
}
func ipLayerFromPrefix(p netip.Prefix) gopacket.LayerType {
if p.Addr().Is6() {
return layers.LayerTypeIPv6
}
return layers.LayerTypeIPv4
}
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType { func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
switch proto { switch proto {
case firewall.ProtocolTCP: case firewall.ProtocolTCP:
@@ -1123,8 +1274,10 @@ func getProtocolFromPacket(d *decoder) nftypes.Protocol {
return nftypes.TCP return nftypes.TCP
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
return nftypes.UDP return nftypes.UDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4:
return nftypes.ICMP return nftypes.ICMP
case layers.LayerTypeICMPv6:
return nftypes.ICMPv6
default: default:
return nftypes.ProtocolUnknown return nftypes.ProtocolUnknown
} }
@@ -1145,8 +1298,10 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
// It returns true, false if the packet is valid and not a fragment. // It returns true, false if the packet is valid and not a fragment.
// It returns true, true if the packet is a fragment and valid. // It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
m.logger.Trace1("couldn't decode packet, err: %s", err) if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace1("couldn't decode packet, err: %s", err)
}
return false, false return false, false
} }
@@ -1158,10 +1313,21 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
} }
// Fragments are also valid // Fragments are also valid
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 { if l == 1 {
ip4 := d.ip4 switch d.decoded[0] {
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 { case layers.LayerTypeIPv4:
return true, true if d.ip4.Flags&layers.IPv4MoreFragments != 0 || d.ip4.FragOffset != 0 {
return true, true
}
case layers.LayerTypeIPv6:
// IPv6 uses Fragment extension header (NextHeader=44). If gopacket
// only decoded the IPv6 layer, the transport is in a fragment.
// TODO: handle non-Fragment extension headers (HopByHop, Routing,
// DestOpts) by walking the chain. gopacket's parser does not
// support them as DecodingLayers; today we drop such packets.
if d.ip6.NextHeader == layers.IPProtocolIPv6Fragment {
return true, true
}
} }
} }
@@ -1199,21 +1365,35 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
size, size,
) )
// TODO: ICMPv6 case layers.LayerTypeICMPv6:
id, _ := icmpv6EchoFields(d)
return m.icmpTracker.IsValidInbound(
srcIP,
dstIP,
id,
d.icmp6.TypeCode.Type(),
size,
)
} }
return false return false
} }
// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed // isSpecialICMP returns true if the packet is a special ICMP error packet that should be allowed.
func (m *Manager) isSpecialICMP(d *decoder) bool { func (m *Manager) isSpecialICMP(d *decoder) bool {
if d.decoded[1] != layers.LayerTypeICMPv4 { switch d.decoded[1] {
return false case layers.LayerTypeICMPv4:
icmpType := d.icmp4.TypeCode.Type()
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
icmpType == layers.ICMPv4TypeTimeExceeded
case layers.LayerTypeICMPv6:
icmpType := d.icmp6.TypeCode.Type()
return icmpType == layers.ICMPv6TypeDestinationUnreachable ||
icmpType == layers.ICMPv6TypePacketTooBig ||
icmpType == layers.ICMPv6TypeTimeExceeded ||
icmpType == layers.ICMPv6TypeParameterProblem
} }
return false
icmpType := d.icmp4.TypeCode.Type()
return icmpType == layers.ICMPv4TypeDestinationUnreachable ||
icmpType == layers.ICMPv4TypeTimeExceeded
} }
func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) { func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
@@ -1270,7 +1450,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true return rule.mgmtId, rule.drop, true
} }
if payloadLayer != rule.protoLayer { if !protoLayerMatches(rule.protoLayer, payloadLayer) {
continue continue
} }
@@ -1305,8 +1485,7 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.Lay
} }
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool { func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
// TODO: handle ipv6 vs ipv4 icmp rules if rule.protoLayer != layerTypeAll && !protoLayerMatches(rule.protoLayer, protoLayer) {
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
return false return false
} }
@@ -1367,13 +1546,14 @@ func (m *Manager) EnableRouting() error {
return nil return nil
} }
rule, err := m.blockInvalidRouted(m.wgIface) rules, err := m.blockInvalidRouted(m.wgIface)
// Persist whatever was installed even on partial failure, so DisableRouting
// can clean it up later.
m.blockRules = rules
if err != nil { if err != nil {
return fmt.Errorf("block invalid routed: %w", err) return fmt.Errorf("block invalid routed: %w", err)
} }
m.blockRule = rule
return nil return nil
} }
@@ -1389,9 +1569,16 @@ func (m *Manager) DisableRouting() error {
m.routingEnabled.Store(false) m.routingEnabled.Store(false)
m.nativeRouter.Store(false) m.nativeRouter.Store(false)
// don't stop forwarder if in use by netstack var merr *multierror.Error
for _, rule := range m.blockRules {
if err := m.deleteRouteRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete block rule: %w", err))
}
}
m.blockRules = nil
if m.netstack && m.localForwarding { if m.netstack && m.localForwarding {
return nil return nberrors.FormatErrorOrNil(merr)
} }
fwder.Stop() fwder.Stop()
@@ -1399,14 +1586,7 @@ func (m *Manager) DisableRouting() error {
log.Debug("forwarder stopped") log.Debug("forwarder stopped")
if m.blockRule != nil { return nberrors.FormatErrorOrNil(merr)
if err := m.deleteRouteRule(m.blockRule); err != nil {
return fmt.Errorf("delete block rule: %w", err)
}
m.blockRule = nil
}
return nil
} }
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port // RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
@@ -1460,7 +1640,8 @@ func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
} }
// traffic to our other local interfaces (not NetBird IP) - always forward // traffic to our other local interfaces (not NetBird IP) - always forward
if dstIP != m.wgIface.Address().IP { addr := m.wgIface.Address()
if dstIP != addr.IP && (!addr.IPv6.IsValid() || dstIP != addr.IPv6) {
return true return true
} }

View File

@@ -1023,7 +1023,8 @@ func BenchmarkMSSClamping(b *testing.B) {
}() }()
manager.mssClampEnabled = true manager.mssClampEnabled = true
manager.mssClampValue = 1240 manager.mssClampValueIPv4 = 1240
manager.mssClampValueIPv6 = 1220
srcIP := net.ParseIP("100.64.0.2") srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8") dstIP := net.ParseIP("8.8.8.8")
@@ -1088,7 +1089,8 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
manager.mssClampEnabled = sc.enabled manager.mssClampEnabled = sc.enabled
if sc.enabled { if sc.enabled {
manager.mssClampValue = 1240 manager.mssClampValueIPv4 = 1240
manager.mssClampValueIPv6 = 1220
} }
srcIP := net.ParseIP("100.64.0.2") srcIP := net.ParseIP("100.64.0.2")
@@ -1141,7 +1143,8 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
}() }()
manager.mssClampEnabled = true manager.mssClampEnabled = true
manager.mssClampValue = 1240 manager.mssClampValueIPv4 = 1240
manager.mssClampValueIPv6 = 1220
srcIP := net.ParseIP("100.64.0.2") srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8") dstIP := net.ParseIP("8.8.8.8")

View File

@@ -539,53 +539,236 @@ func TestPeerACLFiltering(t *testing.T) {
} }
} }
func TestPeerACLFilteringIPv6(t *testing.T) {
localIP := netip.MustParseAddr("100.10.0.100")
localIPv6 := netip.MustParseAddr("fd00::100")
wgNet := netip.MustParsePrefix("100.10.0.0/16")
wgNetV6 := netip.MustParsePrefix("fd00::/64")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: localIP,
Network: wgNet,
IPv6: localIPv6,
IPv6Net: wgNetV6,
}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
err = manager.UpdateLocalIPs()
require.NoError(t, err)
testCases := []struct {
name string
srcIP string
dstIP string
proto fw.Protocol
srcPort uint16
dstPort uint16
ruleIP string
ruleProto fw.Protocol
ruleDstPort *fw.Port
ruleAction fw.Action
shouldBeBlocked bool
}{
{
name: "IPv6: allow TCP from peer",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: allow UDP from peer",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{53}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: allow ICMPv6 from peer",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolICMP,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: block TCP without rule",
srcIP: "fd00::2",
dstIP: "fd00::100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "IPv6: drop rule",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 22,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{22}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "IPv6: allow all protocols",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 9999,
ruleIP: "fd00::1",
ruleProto: fw.ProtocolALL,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches",
srcIP: "fd00::1",
dstIP: "fd00::100",
proto: fw.ProtocolICMP,
ruleIP: "0.0.0.0",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
}
t.Run("IPv6 implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "fd00::1", "fd00::100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.FilterInbound(packet, 0)
require.True(t, isDropped, "IPv6 packet should be dropped when no rules exist")
})
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
require.NoError(t, err)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
require.NoError(t, err)
require.NotEmpty(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.FilterInbound(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped, "packet filter result mismatch")
})
}
}
func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte { func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte {
t.Helper() t.Helper()
src := net.ParseIP(srcIP)
dst := net.ParseIP(dstIP)
buf := gopacket.NewSerializeBuffer() buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ opts := gopacket.SerializeOptions{
ComputeChecksums: true, ComputeChecksums: true,
FixLengths: true, FixLengths: true,
} }
ipLayer := &layers.IPv4{ // Detect address family
Version: 4, isV6 := src.To4() == nil
TTL: 64,
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
}
var err error var err error
switch proto {
case fw.ProtocolTCP:
ipLayer.Protocol = layers.IPProtocolTCP
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
}
err = tcp.SetNetworkLayerForChecksum(ipLayer)
require.NoError(t, err)
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp)
case fw.ProtocolUDP: if isV6 {
ipLayer.Protocol = layers.IPProtocolUDP ip6 := &layers.IPv6{
udp := &layers.UDP{ Version: 6,
SrcPort: layers.UDPPort(srcPort), HopLimit: 64,
DstPort: layers.UDPPort(dstPort), SrcIP: src,
DstIP: dst,
} }
err = udp.SetNetworkLayerForChecksum(ipLayer)
require.NoError(t, err)
err = gopacket.SerializeLayers(buf, opts, ipLayer, udp)
case fw.ProtocolICMP: switch proto {
ipLayer.Protocol = layers.IPProtocolICMPv4 case fw.ProtocolTCP:
icmp := &layers.ICMPv4{ ip6.NextHeader = layers.IPProtocolTCP
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0), tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
_ = tcp.SetNetworkLayerForChecksum(ip6)
err = gopacket.SerializeLayers(buf, opts, ip6, tcp)
case fw.ProtocolUDP:
ip6.NextHeader = layers.IPProtocolUDP
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
_ = udp.SetNetworkLayerForChecksum(ip6)
err = gopacket.SerializeLayers(buf, opts, ip6, udp)
case fw.ProtocolICMP:
ip6.NextHeader = layers.IPProtocolICMPv6
icmp := &layers.ICMPv6{
TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0),
}
_ = icmp.SetNetworkLayerForChecksum(ip6)
err = gopacket.SerializeLayers(buf, opts, ip6, icmp)
default:
err = gopacket.SerializeLayers(buf, opts, ip6)
}
} else {
ip4 := &layers.IPv4{
Version: 4,
TTL: 64,
SrcIP: src,
DstIP: dst,
} }
err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp)
default: switch proto {
err = gopacket.SerializeLayers(buf, opts, ipLayer) case fw.ProtocolTCP:
ip4.Protocol = layers.IPProtocolTCP
tcp := &layers.TCP{SrcPort: layers.TCPPort(srcPort), DstPort: layers.TCPPort(dstPort)}
_ = tcp.SetNetworkLayerForChecksum(ip4)
err = gopacket.SerializeLayers(buf, opts, ip4, tcp)
case fw.ProtocolUDP:
ip4.Protocol = layers.IPProtocolUDP
udp := &layers.UDP{SrcPort: layers.UDPPort(srcPort), DstPort: layers.UDPPort(dstPort)}
_ = udp.SetNetworkLayerForChecksum(ip4)
err = gopacket.SerializeLayers(buf, opts, ip4, udp)
case fw.ProtocolICMP:
ip4.Protocol = layers.IPProtocolICMPv4
icmp := &layers.ICMPv4{TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0)}
err = gopacket.SerializeLayers(buf, opts, ip4, icmp)
default:
err = gopacket.SerializeLayers(buf, opts, ip4)
}
} }
require.NoError(t, err) require.NoError(t, err)
@@ -1498,3 +1681,103 @@ func TestRouteACLSet(t *testing.T) {
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80) _, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
} }
// TestRouteACLFilteringIPv6 tests IPv6 route ACL matching directly via routeACLsPass.
// Note: full FilterInbound for routed IPv6 traffic drops at the forwarder stage (IPv4-only)
// but the ACL decision itself is correct.
func TestRouteACLFilteringIPv6(t *testing.T) {
manager := setupRoutedManager(t, "10.10.0.100/16")
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
_, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: v6Dst},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{80}},
fw.ActionAccept,
)
require.NoError(t, err)
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
fw.ProtocolALL,
nil,
nil,
fw.ActionDrop,
)
require.NoError(t, err)
tests := []struct {
name string
srcIP netip.Addr
dstIP netip.Addr
proto gopacket.LayerType
srcPort uint16
dstPort uint16
allowed bool
}{
{
name: "IPv6 TCP to allowed dest",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 80,
allowed: true,
},
{
name: "IPv6 TCP wrong port",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 443,
allowed: false,
},
{
name: "IPv6 UDP not matched by TCP rule",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeUDP,
srcPort: 12345,
dstPort: 80,
allowed: false,
},
{
name: "IPv6 ICMPv6 matches ICMP rule via protoLayerMatches",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeICMPv6,
allowed: false,
},
{
name: "IPv6 to denied subnet",
srcIP: netip.MustParseAddr("fd00::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef:1::1"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 80,
allowed: false,
},
{
name: "IPv6 source outside allowed range",
srcIP: netip.MustParseAddr("fe80::1"),
dstIP: netip.MustParseAddr("fd00:dead:beef::80"),
proto: layers.LayerTypeTCP,
srcPort: 12345,
dstPort: 80,
allowed: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, pass := manager.routeACLsPass(tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.allowed, pass, "route ACL result mismatch")
})
}
}

View File

@@ -189,21 +189,21 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) {
}) })
// Call blockInvalidRouted directly multiple times // Call blockInvalidRouted directly multiple times
rule1, err := manager.blockInvalidRouted(ifaceMock) rules1, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, rule1) require.NotEmpty(t, rules1)
rule2, err := manager.blockInvalidRouted(ifaceMock) rules2, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, rule2) require.NotEmpty(t, rules2)
rule3, err := manager.blockInvalidRouted(ifaceMock) rules3, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, rule3) require.NotEmpty(t, rules3)
// All should return the same rule // All calls should return the same v4 block rule (idempotent install).
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule") assert.Equal(t, rules1[0].ID(), rules2[0].ID(), "Second call should return same v4 rule")
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule") assert.Equal(t, rules2[0].ID(), rules3[0].ID(), "Third call should return same v4 rule")
// Should have exactly 1 route rule // Should have exactly 1 route rule
manager.mutex.RLock() manager.mutex.RLock()

View File

@@ -535,11 +535,16 @@ func TestProcessOutgoingHooks(t *testing.T) {
d := &decoder{ d := &decoder{
decoded: []gopacket.LayerType{}, decoded: []gopacket.LayerType{},
} }
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
return d return d
}, },
} }
@@ -638,11 +643,16 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
d := &decoder{ d := &decoder{
decoded: []gopacket.LayerType{}, decoded: []gopacket.LayerType{},
} }
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
return d return d
}, },
} }
@@ -1048,8 +1058,8 @@ func TestMSSClamping(t *testing.T) {
}() }()
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default") require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize) require.Equal(t, uint16(1280-ipv4TCPHeaderMinSize), manager.mssClampValueIPv4, "IPv4 MSS clamp value should be MTU - 40")
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40") require.Equal(t, uint16(1280-ipv6TCPHeaderMinSize), manager.mssClampValueIPv6, "IPv6 MSS clamp value should be MTU - 60")
err = manager.UpdateLocalIPs() err = manager.UpdateLocalIPs()
require.NoError(t, err) require.NoError(t, err)
@@ -1067,7 +1077,7 @@ func TestMSSClamping(t *testing.T) {
require.Len(t, d.tcp.Options, 1, "Should have MSS option") require.Len(t, d.tcp.Options, 1, "Should have MSS option")
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType)) require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40") require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS should be clamped to MTU - 40")
}) })
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) { t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
@@ -1091,7 +1101,7 @@ func TestMSSClamping(t *testing.T) {
d := parsePacket(t, packet) d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option") require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped") require.Equal(t, manager.mssClampValueIPv4, actualMSS, "MSS in SYN-ACK should be clamped")
}) })
t.Run("Non-SYN packet unchanged", func(t *testing.T) { t.Run("Non-SYN packet unchanged", func(t *testing.T) {
@@ -1263,13 +1273,18 @@ func TestShouldForward(t *testing.T) {
d := &decoder{ d := &decoder{
decoded: []gopacket.LayerType{}, decoded: []gopacket.LayerType{},
} }
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded) err = d.decodePacket(buf.Bytes())
require.NoError(t, err) require.NoError(t, err)
return d return d
@@ -1329,6 +1344,44 @@ func TestShouldForward(t *testing.T) {
}, },
} }
// Add IPv6 to the interface and test dual-stack cases
wgIPv6 := netip.MustParseAddr("fd00::1")
otherIPv6 := netip.MustParseAddr("fd00::2")
ifaceMock.AddressFunc = func() wgaddr.Address {
return wgaddr.Address{
IP: wgIP,
Network: netip.PrefixFrom(wgIP, 24),
IPv6: wgIPv6,
IPv6Net: netip.PrefixFrom(wgIPv6, 64),
}
}
// Re-create manager to pick up the new address with IPv6
require.NoError(t, manager.Close(nil))
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
v6Cases := []struct {
name string
dstIP netip.Addr
expected bool
description string
}{
{"v6 traffic to other address", otherIPv6, true, "should forward v6 traffic not destined to our v6 address"},
{"v6 traffic to our v6 IP", wgIPv6, false, "should not forward traffic destined to our v6 address"},
{"v4 traffic to other with v6 configured", otherIP, true, "should forward v4 traffic when v6 configured"},
{"v4 traffic to our v4 IP with v6 configured", wgIP, false, "should not forward traffic to our v4 address"},
}
for _, tt := range v6Cases {
t.Run(tt.name, func(t *testing.T) {
manager.localForwarding = true
manager.netstack = false
decoder := createTCPDecoder(8080)
result := manager.shouldForward(decoder, tt.dstIP)
require.Equal(t, tt.expected, result, tt.description)
})
}
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Configure manager // Configure manager

View File

@@ -1,7 +1,8 @@
package forwarder package forwarder
import ( import (
"fmt" "net"
"strconv"
"sync/atomic" "sync/atomic"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
@@ -54,16 +55,23 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int var written int
for _, pkt := range pkts.AsSlice() { for _, pkt := range pkts.AsSlice() {
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
data := stack.PayloadSince(pkt.NetworkHeader()) data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil { if data == nil {
continue continue
} }
pktBytes := data.AsSlice() raw := pkt.NetworkHeader().View().AsSlice()
if len(raw) == 0 {
continue
}
var address tcpip.Address
if raw[0]>>4 == 6 {
address = header.IPv6(raw).DestinationAddress()
} else {
address = header.IPv4(raw).DestinationAddress()
}
address := netHeader.DestinationAddress() pktBytes := data.AsSlice()
if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil { if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil {
e.logger.Error1("CreateOutboundPacket: %v", err) e.logger.Error1("CreateOutboundPacket: %v", err)
continue continue
@@ -114,5 +122,7 @@ type epID stack.TransportEndpointID
func (i epID) String() string { func (i epID) String() string {
// src and remote is swapped // src and remote is swapped
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) return net.JoinHostPort(i.RemoteAddress.String(), strconv.Itoa(int(i.RemotePort))) +
" → " +
net.JoinHostPort(i.LocalAddress.String(), strconv.Itoa(int(i.LocalPort)))
} }

View File

@@ -14,6 +14,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -36,25 +37,31 @@ type Forwarder struct {
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection // ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map ruleIdMap sync.Map
stack *stack.Stack stack *stack.Stack
endpoint *endpoint endpoint *endpoint
udpForwarder *udpForwarder udpForwarder *udpForwarder
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ip tcpip.Address ip tcpip.Address
netstack bool ipv6 tcpip.Address
hasRawICMPAccess bool netstack bool
pingSemaphore chan struct{} hasRawICMPAccess bool
hasRawICMPv6Access bool
pingSemaphore chan struct{}
} }
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
s := stack.New(stack.Options{ s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{ TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol, tcp.NewProtocol,
udp.NewProtocol, udp.NewProtocol,
icmp.NewProtocol4, icmp.NewProtocol4,
icmp.NewProtocol6,
}, },
HandleLocal: false, HandleLocal: false,
}) })
@@ -73,7 +80,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{ AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), Address: tcpip.AddrFrom4(iface.Address().IP.As4()),
PrefixLen: iface.Address().Network.Bits(), PrefixLen: iface.Address().Network.Bits(),
}, },
} }
@@ -82,6 +89,19 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to add protocol address: %s", err) return nil, fmt.Errorf("failed to add protocol address: %s", err)
} }
if v6 := iface.Address().IPv6; v6.IsValid() {
v6Addr := tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFrom16(v6.As16()),
PrefixLen: iface.Address().IPv6Net.Bits(),
},
}
if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("add IPv6 protocol address: %s", err)
}
}
defaultSubnet, err := tcpip.NewSubnet( defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}), tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
@@ -90,6 +110,14 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("creating default subnet: %w", err) return nil, fmt.Errorf("creating default subnet: %w", err)
} }
defaultSubnetV6, err := tcpip.NewSubnet(
tcpip.AddrFrom16([16]byte{}),
tcpip.MaskFromBytes(make([]byte, 16)),
)
if err != nil {
return nil, fmt.Errorf("creating default v6 subnet: %w", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil { if err := s.SetPromiscuousMode(nicID, true); err != nil {
return nil, fmt.Errorf("set promiscuous mode: %s", err) return nil, fmt.Errorf("set promiscuous mode: %s", err)
} }
@@ -98,10 +126,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
} }
s.SetRouteTable([]tcpip.Route{ s.SetRouteTable([]tcpip.Route{
{ {Destination: defaultSubnet, NIC: nicID},
Destination: defaultSubnet, {Destination: defaultSubnetV6, NIC: nicID},
NIC: nicID,
},
}) })
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@@ -114,7 +140,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()), ip: tcpip.AddrFrom4(iface.Address().IP.As4()),
ipv6: addrFromNetipAddr(iface.Address().IPv6),
pingSemaphore: make(chan struct{}, 3), pingSemaphore: make(chan struct{}, 3),
} }
@@ -131,7 +158,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
udpForwarder := udp.NewForwarder(s, f.handleUDP) udpForwarder := udp.NewForwarder(s, f.handleUDP)
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) // ICMP is handled directly in InjectIncomingPacket, bypassing gVisor's
// network layer. This avoids duplicate echo replies (v4) and the v6
// auto-reply bug where gVisor responds at the network layer before
// our transport handler fires.
f.checkICMPCapability() f.checkICMPCapability()
@@ -150,8 +180,30 @@ func (f *Forwarder) SetCapture(pc PacketCapture) {
} }
func (f *Forwarder) InjectIncomingPacket(payload []byte) error { func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
if len(payload) < header.IPv4MinimumSize { if len(payload) == 0 {
return fmt.Errorf("packet too small: %d bytes", len(payload)) return fmt.Errorf("empty packet")
}
var protoNum tcpip.NetworkProtocolNumber
switch payload[0] >> 4 {
case 4:
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("IPv4 packet too small: %d bytes", len(payload))
}
if f.handleICMPDirect(payload) {
return nil
}
protoNum = ipv4.ProtocolNumber
case 6:
if len(payload) < header.IPv6MinimumSize {
return fmt.Errorf("IPv6 packet too small: %d bytes", len(payload))
}
if f.handleICMPDirect(payload) {
return nil
}
protoNum = ipv6.ProtocolNumber
default:
return fmt.Errorf("unknown IP version: %d", payload[0]>>4)
} }
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -160,11 +212,160 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
defer pkt.DecRef() defer pkt.DecRef()
if f.endpoint.dispatcher != nil { if f.endpoint.dispatcher != nil {
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt) f.endpoint.dispatcher.DeliverNetworkPacket(protoNum, pkt)
} }
return nil return nil
} }
// handleICMPDirect intercepts ICMP packets from raw IP payloads before they
// enter gVisor. It synthesizes the TransportEndpointID and PacketBuffer that
// the existing handlers expect, then dispatches to handleICMP/handleICMPv6.
// This bypasses gVisor's network layer which causes duplicate v4 echo replies
// and auto-replies to all v6 echo requests in promiscuous mode.
//
// Unlike gVisor's network layer, this does not validate ICMP checksums or
// reassemble IP fragments. Fragmented ICMP packets fall through to gVisor.
func parseICMPv4(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) {
if len(payload) < header.IPv4MinimumSize {
return 0, 0, src, dst, false
}
ip := header.IPv4(payload)
if ip.Protocol() != uint8(header.ICMPv4ProtocolNumber) {
return 0, 0, src, dst, false
}
if ip.FragmentOffset() != 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 {
return 0, 0, src, dst, false
}
ipHdrLen = int(ip.HeaderLength())
totalLen := int(ip.TotalLength())
if ipHdrLen < header.IPv4MinimumSize || ipHdrLen > totalLen || totalLen > len(payload) {
return 0, 0, src, dst, false
}
icmpLen = totalLen - ipHdrLen
if icmpLen < header.ICMPv4MinimumSize {
return 0, 0, src, dst, false
}
return ipHdrLen, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true
}
func parseICMPv6(payload []byte) (ipHdrLen, icmpLen int, src, dst tcpip.Address, ok bool) {
if len(payload) < header.IPv6MinimumSize {
return 0, 0, src, dst, false
}
ip := header.IPv6(payload)
declaredLen := int(ip.PayloadLength())
hdrEnd := header.IPv6MinimumSize + declaredLen
if hdrEnd > len(payload) {
return 0, 0, src, dst, false
}
icmpStart, ok := skipIPv6ExtensionsToICMPv6(payload, ip.NextHeader(), hdrEnd)
if !ok {
return 0, 0, src, dst, false
}
icmpLen = hdrEnd - icmpStart
if icmpLen < header.ICMPv6MinimumSize {
return 0, 0, src, dst, false
}
return icmpStart, icmpLen, ip.SourceAddress(), ip.DestinationAddress(), true
}
// skipIPv6ExtensionsToICMPv6 walks the IPv6 extension-header chain starting
// after the fixed header. It advances past Hop-by-Hop, Routing, and
// Destination Options headers (which share the NextHeader+ExtLen+6+ExtLen*8
// layout) and returns the offset of the ICMPv6 payload. Fragment, ESP, AH,
// and unknown identifiers are reported as not handleable so the caller can
// defer to gVisor.
func skipIPv6ExtensionsToICMPv6(payload []byte, next uint8, hdrEnd int) (int, bool) {
off := header.IPv6MinimumSize
for {
if next == uint8(header.ICMPv6ProtocolNumber) {
return off, true
}
if !isWalkableIPv6ExtHdr(next) {
return 0, false
}
newOff, newNext, ok := advanceIPv6ExtHdr(payload, off, hdrEnd)
if !ok {
return 0, false
}
off = newOff
next = newNext
}
}
func isWalkableIPv6ExtHdr(id uint8) bool {
switch id {
case uint8(header.IPv6HopByHopOptionsExtHdrIdentifier),
uint8(header.IPv6RoutingExtHdrIdentifier),
uint8(header.IPv6DestinationOptionsExtHdrIdentifier):
return true
}
return false
}
func advanceIPv6ExtHdr(payload []byte, off, hdrEnd int) (int, uint8, bool) {
if off+8 > hdrEnd {
return 0, 0, false
}
extLen := (int(payload[off+1]) + 1) * 8
if off+extLen > hdrEnd {
return 0, 0, false
}
return off + extLen, payload[off], true
}
func (f *Forwarder) handleICMPDirect(payload []byte) bool {
if len(payload) == 0 {
return false
}
var (
ipHdrLen int
icmpLen int
srcAddr tcpip.Address
dstAddr tcpip.Address
ok bool
)
version := payload[0] >> 4
switch version {
case 4:
ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv4(payload)
case 6:
ipHdrLen, icmpLen, srcAddr, dstAddr, ok = parseICMPv6(payload)
}
if !ok {
return false
}
// Let gVisor handle ICMP destined for our own addresses natively.
// Its network-layer auto-reply is correct and efficient for local traffic.
if f.ip.Equal(dstAddr) || f.ipv6.Equal(dstAddr) {
return false
}
id := stack.TransportEndpointID{
LocalAddress: dstAddr,
RemoteAddress: srcAddr,
}
// Trim the buffer to the IP-declared length so gVisor doesn't see padding.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload[:ipHdrLen+icmpLen]),
})
defer pkt.DecRef()
if _, ok := pkt.NetworkHeader().Consume(ipHdrLen); !ok {
return false
}
if _, ok := pkt.TransportHeader().Consume(icmpLen); !ok {
return false
}
if version == 6 {
return f.handleICMPv6(id, pkt)
}
return f.handleICMP(id, pkt)
}
// Stop gracefully shuts down the forwarder // Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() { func (f *Forwarder) Stop() {
f.cancel() f.cancel()
@@ -177,11 +378,14 @@ func (f *Forwarder) Stop() {
f.stack.Wait() f.stack.Wait()
} }
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { func (f *Forwarder) determineDialAddr(addr tcpip.Address) netip.Addr {
if f.netstack && f.ip.Equal(addr) { if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1) return netip.AddrFrom4([4]byte{127, 0, 0, 1})
} }
return addr.AsSlice() if f.netstack && f.ipv6.Equal(addr) {
return netip.IPv6Loopback()
}
return addrToNetipAddr(addr)
} }
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) { func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
@@ -215,23 +419,50 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
} }
} }
// addrFromNetipAddr converts a netip.Addr to a gvisor tcpip.Address without allocating.
func addrFromNetipAddr(addr netip.Addr) tcpip.Address {
if !addr.IsValid() {
return tcpip.Address{}
}
if addr.Is4() {
return tcpip.AddrFrom4(addr.As4())
}
return tcpip.AddrFrom16(addr.As16())
}
// addrToNetipAddr converts a gvisor tcpip.Address to netip.Addr without allocating.
func addrToNetipAddr(addr tcpip.Address) netip.Addr {
switch addr.Len() {
case 4:
return netip.AddrFrom4(addr.As4())
case 16:
return netip.AddrFrom16(addr.As16())
default:
return netip.Addr{}
}
}
// checkICMPCapability tests whether we have raw ICMP socket access at startup. // checkICMPCapability tests whether we have raw ICMP socket access at startup.
func (f *Forwarder) checkICMPCapability() { func (f *Forwarder) checkICMPCapability() {
f.hasRawICMPAccess = probeRawICMP("ip4:icmp", "0.0.0.0", f.logger)
f.hasRawICMPv6Access = probeRawICMP("ip6:ipv6-icmp", "::", f.logger)
}
func probeRawICMP(network, addr string, logger *nblog.Logger) bool {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel() defer cancel()
lc := net.ListenConfig{} lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") conn, err := lc.ListenPacket(ctx, network, addr)
if err != nil { if err != nil {
f.hasRawICMPAccess = false logger.Debug1("forwarder: no raw %s socket access, will use ping binary fallback", network)
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback") return false
return
} }
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err) logger.Debug2("forwarder: failed to close %s capability test socket: %v", network, err)
} }
f.hasRawICMPAccess = true logger.Debug1("forwarder: raw %s socket access available", network)
f.logger.Debug("forwarder: Raw ICMP socket access available") return true
} }

View File

@@ -0,0 +1,162 @@
package forwarder
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
const echoRequestSize = 8
func makeIPv6(t *testing.T, src, dst netip.Addr, nextHdr uint8, payload []byte) []byte {
t.Helper()
buf := make([]byte, header.IPv6MinimumSize+len(payload))
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(len(payload)),
TransportProtocol: 0, // overwritten below to allow any value
HopLimit: 64,
SrcAddr: tcpipAddrFromNetip(src),
DstAddr: tcpipAddrFromNetip(dst),
})
buf[6] = nextHdr
copy(buf[header.IPv6MinimumSize:], payload)
return buf
}
func tcpipAddrFromNetip(a netip.Addr) tcpip.Address {
b := a.As16()
return tcpip.AddrFrom16(b)
}
func echoRequest() []byte {
icmp := make([]byte, echoRequestSize)
icmp[0] = uint8(header.ICMPv6EchoRequest)
return icmp
}
// extHdr builds a generic IPv6 extension header (HBH/Routing/DestOpts) of the
// given total octet length (must be multiple of 8, >= 8) with the given next
// header.
func extHdr(t *testing.T, next uint8, totalLen int) []byte {
t.Helper()
require.GreaterOrEqual(t, totalLen, 8)
require.Equal(t, 0, totalLen%8)
buf := make([]byte, totalLen)
buf[0] = next
buf[1] = uint8(totalLen/8 - 1)
return buf
}
func TestParseICMPv6_NoExtensions(t *testing.T) {
src := netip.MustParseAddr("fd00::1")
dst := netip.MustParseAddr("fd00::2")
pkt := makeIPv6(t, src, dst, uint8(header.ICMPv6ProtocolNumber), echoRequest())
off, icmpLen, _, _, ok := parseICMPv6(pkt)
require.True(t, ok)
assert.Equal(t, header.IPv6MinimumSize, off)
assert.Equal(t, echoRequestSize, icmpLen)
}
func TestParseICMPv6_SingleExtension(t *testing.T) {
src := netip.MustParseAddr("fd00::1")
dst := netip.MustParseAddr("fd00::2")
hbh := extHdr(t, uint8(header.ICMPv6ProtocolNumber), 8)
payload := append([]byte{}, hbh...)
payload = append(payload, echoRequest()...)
pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), payload)
off, icmpLen, _, _, ok := parseICMPv6(pkt)
require.True(t, ok)
assert.Equal(t, header.IPv6MinimumSize+8, off)
assert.Equal(t, echoRequestSize, icmpLen)
}
func TestParseICMPv6_ChainedExtensions(t *testing.T) {
src := netip.MustParseAddr("fd00::1")
dst := netip.MustParseAddr("fd00::2")
dest := extHdr(t, uint8(header.ICMPv6ProtocolNumber), 16)
rt := extHdr(t, uint8(header.IPv6DestinationOptionsExtHdrIdentifier), 8)
hbh := extHdr(t, uint8(header.IPv6RoutingExtHdrIdentifier), 8)
payload := append(append(append([]byte{}, hbh...), rt...), dest...)
payload = append(payload, echoRequest()...)
pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), payload)
off, icmpLen, _, _, ok := parseICMPv6(pkt)
require.True(t, ok)
assert.Equal(t, header.IPv6MinimumSize+8+8+16, off)
assert.Equal(t, echoRequestSize, icmpLen)
}
func TestParseICMPv6_FragmentDefersToGVisor(t *testing.T) {
src := netip.MustParseAddr("fd00::1")
dst := netip.MustParseAddr("fd00::2")
pkt := makeIPv6(t, src, dst, uint8(header.IPv6FragmentExtHdrIdentifier), make([]byte, 8))
_, _, _, _, ok := parseICMPv6(pkt)
assert.False(t, ok)
}
func TestParseICMPv6_TruncatedExtension(t *testing.T) {
src := netip.MustParseAddr("fd00::1")
dst := netip.MustParseAddr("fd00::2")
// Extension claims 16 bytes but only 8 remain after the IP header.
hbh := []byte{uint8(header.ICMPv6ProtocolNumber), 1, 0, 0, 0, 0, 0, 0}
pkt := makeIPv6(t, src, dst, uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), hbh)
_, _, _, _, ok := parseICMPv6(pkt)
assert.False(t, ok)
}
func TestParseICMPv6_TruncatedICMPPayload(t *testing.T) {
src := netip.MustParseAddr("fd00::1")
dst := netip.MustParseAddr("fd00::2")
// PayloadLength claims 8 bytes of ICMPv6 but the buffer only holds 4.
pkt := makeIPv6(t, src, dst, uint8(header.ICMPv6ProtocolNumber), make([]byte, 8))
pkt = pkt[:header.IPv6MinimumSize+4]
_, _, _, _, ok := parseICMPv6(pkt)
assert.False(t, ok)
}
func TestParseICMPv4_RejectsShortIHL(t *testing.T) {
pkt := make([]byte, 28)
pkt[0] = 0x44 // version 4, IHL 4 (16 bytes - below minimum)
pkt[9] = uint8(header.ICMPv4ProtocolNumber)
header.IPv4(pkt).SetTotalLength(28)
_, _, _, _, ok := parseICMPv4(pkt)
assert.False(t, ok)
}
func TestParseICMPv4_RejectsTotalLenOverBuffer(t *testing.T) {
pkt := make([]byte, header.IPv4MinimumSize+header.ICMPv4MinimumSize)
ip := header.IPv4(pkt)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(len(pkt) + 16),
Protocol: uint8(header.ICMPv4ProtocolNumber),
TTL: 64,
})
_, _, _, _, ok := parseICMPv4(pkt)
assert.False(t, ok)
}
func TestParseICMPv4_RejectsFragment(t *testing.T) {
pkt := make([]byte, header.IPv4MinimumSize+header.ICMPv4MinimumSize)
ip := header.IPv4(pkt)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(len(pkt)),
Protocol: uint8(header.ICMPv4ProtocolNumber),
TTL: 64,
Flags: header.IPv4FlagMoreFragments,
})
_, _, _, _, ok := parseICMPv4(pkt)
assert.False(t, ok)
}

View File

@@ -13,6 +13,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
@@ -35,7 +36,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBu
} }
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice() icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond) conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), false, 100*time.Millisecond)
if err != nil { if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err) f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
return true return true
@@ -58,7 +59,7 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
defer func() { <-f.pingSemaphore }() defer func() { <-f.pingSemaphore }()
if f.hasRawICMPAccess { if f.hasRawICMPAccess {
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes) f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, false)
} else { } else {
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes) f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
} }
@@ -72,18 +73,23 @@ func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointI
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection. // forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
// The caller is responsible for closing the returned connection. // The caller is responsible for closing the returned connection.
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) { func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, v6 bool, timeout time.Duration) (net.PacketConn, error) {
ctx, cancel := context.WithTimeout(f.ctx, timeout) ctx, cancel := context.WithTimeout(f.ctx, timeout)
defer cancel() defer cancel()
network, listenAddr := "ip4:icmp", "0.0.0.0"
if v6 {
network, listenAddr = "ip6:ipv6-icmp", "::"
}
lc := net.ListenConfig{} lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") conn, err := lc.ListenPacket(ctx, network, listenAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("create ICMP socket: %w", err) return nil, fmt.Errorf("create ICMP socket: %w", err)
} }
dstIP := f.determineDialAddr(id.LocalAddress) dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP} dst := &net.IPAddr{IP: dstIP.AsSlice()}
if _, err = conn.WriteTo(payload, dst); err != nil { if _, err = conn.WriteTo(payload, dst); err != nil {
if closeErr := conn.Close(); closeErr != nil { if closeErr := conn.Close(); closeErr != nil {
@@ -92,17 +98,19 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
return nil, fmt.Errorf("write ICMP packet: %w", err) return nil, fmt.Errorf("write ICMP packet: %w", err)
} }
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v", if f.logger.Enabled(nblog.LevelTrace) {
epID(id), icmpType, icmpCode) f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
}
return conn, nil return conn, nil
} }
// handleICMPViaSocket handles ICMP echo requests using raw sockets. // handleICMPViaSocket handles ICMP echo requests using raw sockets for both v4 and v6.
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) { func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int, v6 bool) {
sendTime := time.Now() sendTime := time.Now()
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second) conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, v6, 5*time.Second)
if err != nil { if err != nil {
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err) f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
return return
@@ -113,16 +121,22 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
} }
}() }()
txBytes := f.handleEchoResponse(conn, id) txBytes := f.handleEchoResponse(conn, id, v6)
rtt := time.Since(sendTime).Round(10 * time.Microsecond) rtt := time.Since(sendTime).Round(10 * time.Microsecond)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)", if f.logger.Enabled(nblog.LevelTrace) {
epID(id), icmpType, icmpCode, rtt) proto := "ICMP"
if v6 {
proto = "ICMPv6"
}
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
proto, epID(id), icmpType, icmpCode, rtt)
}
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
} }
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int { func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID, v6 bool) int {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err) f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
return 0 return 0
@@ -137,6 +151,19 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
return 0 return 0
} }
if v6 {
// Recompute checksum: the raw socket response has a checksum computed
// over the real endpoint addresses, but we inject with overlay addresses.
icmpHdr := header.ICMPv6(response[:n])
icmpHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: id.LocalAddress,
Dst: id.RemoteAddress,
}))
return f.injectICMPv6Reply(id, response[:n])
}
return f.injectICMPReply(id, response[:n]) return f.injectICMPReply(id, response[:n])
} }
@@ -150,19 +177,23 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
txPackets = 1 txPackets = 1
} }
srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) srcIp := addrToNetipAddr(id.RemoteAddress)
dstIp := netip.AddrFrom4(id.LocalAddress.As4()) dstIp := addrToNetipAddr(id.LocalAddress)
proto := nftypes.ICMP
if srcIp.Is6() {
proto = nftypes.ICMPv6
}
fields := nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.ICMP, Protocol: proto,
// TODO: handle ipv6 SourceIP: srcIp,
SourceIP: srcIp, DestIP: dstIp,
DestIP: dstIp, ICMPType: icmpType,
ICMPType: icmpType, ICMPCode: icmpCode,
ICMPCode: icmpCode,
RxBytes: rxBytes, RxBytes: rxBytes,
TxBytes: txBytes, TxBytes: txBytes,
@@ -198,37 +229,179 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
} }
rtt := time.Since(pingStart).Round(10 * time.Microsecond) rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v", if f.logger.Enabled(nblog.LevelTrace) {
epID(id), icmpType, icmpCode) f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
}
txBytes := f.synthesizeEchoReply(id, icmpData) txBytes := f.synthesizeEchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)", if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
}
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// handleICMPv6 handles ICMPv6 packets from the network stack.
func (f *Forwarder) handleICMPv6(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
icmpHdr := header.ICMPv6(pkt.TransportHeader().View().AsSlice())
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
if icmpHdr.Type() == header.ICMPv6EchoRequest {
return f.handleICMPv6Echo(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
}
// For non-echo types (Destination Unreachable, Packet Too Big, etc), forward without waiting
if !f.hasRawICMPv6Access {
f.logger.Debug2("forwarder: Cannot handle ICMPv6 type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
return false
}
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), true, 100*time.Millisecond)
if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMPv6 packet for %v: %v", epID(id), err)
return true
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMPv6 socket: %v", err)
}
return true
}
// handleICMPv6Echo handles ICMPv6 echo requests via raw socket or ping binary fallback.
func (f *Forwarder) handleICMPv6Echo(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
select {
case f.pingSemaphore <- struct{}{}:
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
rxBytes := pkt.Size()
go func() {
defer func() { <-f.pingSemaphore }()
if f.hasRawICMPv6Access {
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes, true)
} else {
f.handleICMPv6ViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
}
}()
default:
f.logger.Debug3("forwarder: ICMPv6 rate limit exceeded for %v type %v code %v", epID(id), icmpType, icmpCode)
}
return true
}
// handleICMPv6ViaPing uses the system ping6 binary for ICMPv6 echo.
func (f *Forwarder) handleICMPv6ViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
dstIP := f.determineDialAddr(id.LocalAddress)
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
pingStart := time.Now()
if err := cmd.Run(); err != nil {
f.logger.Warn4("forwarder: Ping6 failed for %v type %v code %v: %v", epID(id), icmpType, icmpCode, err)
return
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMPv6 echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
txBytes := f.synthesizeICMPv6EchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMPv6 echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt) epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
} }
// synthesizeICMPv6EchoReply creates an ICMPv6 echo reply and injects it back.
func (f *Forwarder) synthesizeICMPv6EchoReply(id stack.TransportEndpointID, icmpData []byte) int {
replyICMP := make([]byte, len(icmpData))
copy(replyICMP, icmpData)
replyHdr := header.ICMPv6(replyICMP)
replyHdr.SetType(header.ICMPv6EchoReply)
replyHdr.SetChecksum(0)
// ICMPv6Checksum computes the pseudo-header internally from Src/Dst.
// Header contains the full ICMP message, so PayloadCsum/PayloadLen are zero.
replyHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: replyHdr,
Src: id.LocalAddress,
Dst: id.RemoteAddress,
}))
return f.injectICMPv6Reply(id, replyICMP)
}
// injectICMPv6Reply wraps an ICMPv6 payload in an IPv6 header and sends to the peer.
func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload []byte) int {
ipHdr := make([]byte, header.IPv6MinimumSize)
ip := header.IPv6(ipHdr)
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(len(icmpPayload)),
TransportProtocol: header.ICMPv6ProtocolNumber,
HopLimit: 64,
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, icmpPayload...)
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
f.logger.Error1("forwarder: Failed to send ICMPv6 reply to peer: %v", err)
return 0
}
return len(fullPacket)
}
const (
pingBin = "ping"
ping6Bin = "ping6"
)
// buildPingCommand creates a platform-specific ping command. // buildPingCommand creates a platform-specific ping command.
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd { // Most platforms auto-detect IPv6 from raw addresses. macOS/iOS/OpenBSD require ping6.
func buildPingCommand(ctx context.Context, target netip.Addr, timeout time.Duration) *exec.Cmd {
timeoutSec := int(timeout.Seconds()) timeoutSec := int(timeout.Seconds())
if timeoutSec < 1 { if timeoutSec < 1 {
timeoutSec = 1 timeoutSec = 1
} }
isV6 := target.Is6()
timeoutStr := fmt.Sprintf("%d", timeoutSec)
switch runtime.GOOS { switch runtime.GOOS {
case "linux", "android": case "linux", "android":
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String()) return exec.CommandContext(ctx, pingBin, "-c", "1", "-W", timeoutStr, "-q", target.String())
case "darwin", "ios": case "darwin", "ios":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String()) bin := pingBin
if isV6 {
bin = ping6Bin
}
return exec.CommandContext(ctx, bin, "-c", "1", "-t", timeoutStr, "-q", target.String())
case "freebsd": case "freebsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String()) return exec.CommandContext(ctx, pingBin, "-c", "1", "-t", timeoutStr, target.String())
case "openbsd", "netbsd": case "openbsd", "netbsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String()) bin := pingBin
if isV6 {
bin = ping6Bin
}
return exec.CommandContext(ctx, bin, "-c", "1", "-w", timeoutStr, target.String())
case "windows": case "windows":
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String()) return exec.CommandContext(ctx, pingBin, "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
default: default:
return exec.CommandContext(ctx, "ping", "-c", "1", target.String()) return exec.CommandContext(ctx, pingBin, "-c", "1", target.String())
} }
} }

View File

@@ -1,12 +1,8 @@
package forwarder package forwarder
import ( import (
"context"
"fmt"
"io"
"net" "net"
"net/netip" "strconv"
"sync"
"github.com/google/uuid" "github.com/google/uuid"
@@ -16,7 +12,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/util/netrelay"
) )
// handleTCP is called by the TCP forwarder for new connections. // handleTCP is called by the TCP forwarder for new connections.
@@ -33,12 +31,14 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
} }
}() }()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dialAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil { if err != nil {
r.Complete(true) r.Complete(true)
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err) if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
}
return return
} }
@@ -61,64 +61,22 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep) inConn := gonet.NewTCPConn(&wq, ep)
success = true success = true
f.logger.Trace1("forwarder: established TCP connection %v", epID(id)) if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
}
go f.proxyTCP(id, inConn, outConn, ep, flowID) go f.proxyTCP(id, inConn, outConn, ep, flowID)
} }
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) { func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
// netrelay.Relay copies bidirectionally with proper half-close propagation
// and fully closes both conns before returning.
bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{
Logger: f.logger,
})
ctx, cancel := context.WithCancel(f.ctx) // Close the netstack endpoint after both conns are drained.
defer cancel() ep.Close()
go func() {
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesFromInToOut int64 // bytes from client to server (tx for client)
bytesFromOutToIn int64 // bytes from server to client (rx for client)
errInToOut error
errOutToIn error
)
go func() {
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
cancel()
wg.Done()
}()
go func() {
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
cancel()
wg.Done()
}()
wg.Wait()
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
}
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
}
}
var rxPackets, txPackets uint64 var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
@@ -127,21 +85,22 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
txPackets = tcpStats.SegmentsReceived.Value() txPackets = tcpStats.SegmentsReceived.Value()
} }
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
}
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets) f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
} }
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) srcIp := addrToNetipAddr(id.RemoteAddress)
dstIp := netip.AddrFrom4(id.LocalAddress.As4()) dstIp := addrToNetipAddr(id.LocalAddress)
fields := nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.TCP, Protocol: nftypes.TCP,
// TODO: handle ipv6
SourceIP: srcIp, SourceIP: srcIp,
DestIP: dstIp, DestIP: dstIp,
SourcePort: id.RemotePort, SourcePort: id.RemotePort,

View File

@@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip" "strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -125,7 +125,9 @@ func (f *udpForwarder) cleanup() {
delete(f.conns, idle.id) delete(f.conns, idle.id)
f.Unlock() f.Unlock()
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id)) if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
}
} }
} }
} }
@@ -144,7 +146,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
_, exists := f.udpForwarder.conns[id] _, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock() f.udpForwarder.RUnlock()
if exists { if exists {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id)) if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
}
return true return true
} }
@@ -158,7 +162,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
} }
}() }()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dstAddr := net.JoinHostPort(f.determineDialAddr(id.LocalAddress).String(), strconv.Itoa(int(id.LocalPort)))
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil { if err != nil {
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err) f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
@@ -206,7 +210,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
success = true success = true
f.logger.Trace1("forwarder: established UDP connection %v", epID(id)) if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
}
go f.proxyUDP(connCtx, pConn, id, ep) go f.proxyUDP(connCtx, pConn, id, ep)
return true return true
@@ -265,7 +271,9 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
txPackets = udpStats.PacketsReceived.Value() txPackets = udpStats.PacketsReceived.Value()
} }
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
}
f.udpForwarder.Lock() f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id) delete(f.udpForwarder.conns, id)
@@ -276,15 +284,14 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
// sendUDPEvent stores flow events for UDP connections // sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) { func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4()) srcIp := addrToNetipAddr(id.RemoteAddress)
dstIp := netip.AddrFrom4(id.LocalAddress.As4()) dstIp := addrToNetipAddr(id.LocalAddress)
fields := nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.UDP, Protocol: nftypes.UDP,
// TODO: handle ipv6
SourceIP: srcIp, SourceIP: srcIp,
DestIP: dstIp, DestIP: dstIp,
SourcePort: id.RemotePort, SourcePort: id.RemotePort,

View File

@@ -13,7 +13,6 @@ const (
ipv4HeaderMinLen = 20 ipv4HeaderMinLen = 20
ipv4ProtoOffset = 9 ipv4ProtoOffset = 9
ipv4FlagsOffset = 6 ipv4FlagsOffset = 6
ipv4DstOffset = 16
ipProtoUDP = 17 ipProtoUDP = 17
ipProtoTCP = 6 ipProtoTCP = 6
ipv4FragOffMask = 0x1fff ipv4FragOffMask = 0x1fff

View File

@@ -4,89 +4,32 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sync" "sync/atomic"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
) )
type localIPManager struct { // localIPSnapshot is an immutable snapshot of local IP addresses, swapped
mu sync.RWMutex // atomically so reads are lock-free.
type localIPSnapshot struct {
// fixed-size high array for upper byte of a IPv4 address ips map[netip.Addr]struct{}
ipv4Bitmap [256]*ipv4LowBitmap
} }
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address type localIPManager struct {
type ipv4LowBitmap struct { snapshot atomic.Pointer[localIPSnapshot]
bitmap [8192]uint32
} }
func newLocalIPManager() *localIPManager { func newLocalIPManager() *localIPManager {
return &localIPManager{} m := &localIPManager{}
m.snapshot.Store(&localIPSnapshot{
ips: make(map[netip.Addr]struct{}),
})
return m
} }
func (m *localIPManager) setBitmapBit(ip net.IP) { func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresses *[]netip.Addr) {
ipv4 := ip.To4()
if ipv4 == nil {
return
}
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
index := low / 32
bit := low % 32
if m.ipv4Bitmap[high] == nil {
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
}
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
}
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if !ip.Is4() {
return
}
ipv4 := ip.AsSlice()
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
}
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := uint16(ip[0])
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
if m.ipv4Bitmap[high] == nil {
return false
}
index := low / 32
bit := low % 32
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
}
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -104,18 +47,19 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue continue
} }
addr, ok := netip.AddrFromSlice(ip) parsed, ok := netip.AddrFromSlice(ip)
if !ok { if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name) log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue continue
} }
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil { parsed = parsed.Unmap()
log.Debugf("process IP failed: %v", err) ips[parsed] = struct{}{}
} *addresses = append(*addresses, parsed)
} }
} }
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -123,20 +67,20 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
} }
}() }()
var newIPv4Bitmap [256]*ipv4LowBitmap ips := make(map[netip.Addr]struct{})
ipv4Set := make(map[netip.Addr]struct{}) var addresses []netip.Addr
var ipv4Addresses []netip.Addr
// 127.0.0.0/8 // loopback
newIPv4Bitmap[127] = &ipv4LowBitmap{} ips[netip.AddrFrom4([4]byte{127, 0, 0, 1})] = struct{}{}
for i := 0; i < 8192; i++ { ips[netip.IPv6Loopback()] = struct{}{}
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
}
if iface != nil { if iface != nil {
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil { ip := iface.Address().IP
return err ips[ip] = struct{}{}
addresses = append(addresses, ip)
if v6 := iface.Address().IPv6; v6.IsValid() {
ips[v6] = struct{}{}
addresses = append(addresses, v6)
} }
} }
@@ -147,25 +91,24 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse // TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
// case where an interface comes up between refreshes. // case where an interface comes up between refreshes.
for _, intf := range interfaces { for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses) processInterface(intf, ips, &addresses)
} }
} }
m.mu.Lock() m.snapshot.Store(&localIPSnapshot{ips: ips})
m.ipv4Bitmap = newIPv4Bitmap
m.mu.Unlock()
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses) log.Debugf("Local IP addresses: %v", addresses)
return nil return nil
} }
// IsLocalIP checks if the given IP is a local address. Lock-free on the read path.
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool { func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
if !ip.Is4() { s := m.snapshot.Load()
return false
if ip.Is4() && ip.As4()[0] == 127 {
return true
} }
m.mu.RLock() _, found := s.ips[ip]
defer m.mu.RUnlock() return found
return m.checkBitmapBit(ip.AsSlice())
} }

View File

@@ -0,0 +1,72 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func setupManager(b *testing.B) *localIPManager {
b.Helper()
m := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
if err := m.UpdateLocalIPs(mock); err != nil {
b.Fatalf("UpdateLocalIPs: %v", err)
}
return m
}
func BenchmarkIsLocalIP_v4_hit(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("100.64.0.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_v4_miss(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("8.8.8.8")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_v6_hit(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("fd00::1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_v6_miss(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("2001:db8::1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}
func BenchmarkIsLocalIP_loopback(b *testing.B) {
m := setupManager(b)
ip := netip.MustParseAddr("127.0.0.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
m.IsLocalIP(ip)
}
}

View File

@@ -72,14 +72,45 @@ func TestLocalIPManager(t *testing.T) {
expected: false, expected: false,
}, },
{ {
name: "IPv6 address", name: "IPv6 address matches",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("fe80::1"), IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
},
testIP: netip.MustParseAddr("fd00::1"),
expected: true,
},
{
name: "IPv6 address does not match",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: netip.MustParseAddr("fd00::1"),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
},
testIP: netip.MustParseAddr("fd00::99"),
expected: false,
},
{
name: "No aliasing between similar IPs",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: netip.MustParsePrefix("192.168.1.0/24"),
}, },
testIP: netip.MustParseAddr("fe80::1"), testIP: netip.MustParseAddr("192.168.0.17"),
expected: false, expected: false,
}, },
{
name: "IPv6 loopback",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
},
testIP: netip.MustParseAddr("::1"),
expected: true,
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -171,90 +202,3 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
}) })
} }
} }
// MapImplementation is a version using map[string]struct{}
type MapImplementation struct {
localIPs map[string]struct{}
}
func BenchmarkIPChecks(b *testing.B) {
interfaces := make([]net.IP, 16)
for i := range interfaces {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
}
// Setup bitmap
bitmapManager := newLocalIPManager()
for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip)
}
// Setup map version
mapManager := &MapImplementation{
localIPs: make(map[string]struct{}),
}
for _, ip := range interfaces[:8] {
mapManager.localIPs[ip.String()] = struct{}{}
}
b.Run("Bitmap_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Bitmap_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Map_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_ = mapManager.localIPs[ip.String()]
}
})
b.Run("Map_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_ = mapManager.localIPs[ip.String()]
}
})
}
func BenchmarkWGPosition(b *testing.B) {
wgIP := net.ParseIP("10.10.0.1")
// Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) {
bm := newLocalIPManager()
bm.setBitmapBit(wgIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
b.Run("WG_Last", func(b *testing.B) {
bm := newLocalIPManager()
// Fill with other IPs first
for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
}
bm.setBitmapBit(wgIP) // Add WG IP last
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
}

View File

@@ -53,16 +53,17 @@ var levelStrings = map[Level]string{
} }
type logMessage struct { type logMessage struct {
level Level level Level
format string argCount uint8
arg1 any format string
arg2 any arg1 any
arg3 any arg2 any
arg4 any arg3 any
arg5 any arg4 any
arg6 any arg5 any
arg7 any arg6 any
arg8 any arg7 any
arg8 any
} }
// Logger is a high-performance, non-blocking logger // Logger is a high-performance, non-blocking logger
@@ -107,6 +108,13 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
} }
// Enabled reports whether the given level is currently logged. Callers on the
// hot path should guard log sites with this to avoid boxing arguments into
// any when the level is off.
func (l *Logger) Enabled(level Level) bool {
return l.level.Load() >= uint32(level)
}
func (l *Logger) Error(format string) { func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) { if l.level.Load() >= uint32(LevelError) {
select { select {
@@ -155,7 +163,7 @@ func (l *Logger) Trace(format string) {
func (l *Logger) Error1(format string, arg1 any) { func (l *Logger) Error1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelError) { if l.level.Load() >= uint32(LevelError) {
select { select {
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}: case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}:
default: default:
} }
} }
@@ -164,7 +172,16 @@ func (l *Logger) Error1(format string, arg1 any) {
func (l *Logger) Error2(format string, arg1, arg2 any) { func (l *Logger) Error2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelError) { if l.level.Load() >= uint32(LevelError) {
select { select {
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}: case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
}
func (l *Logger) Warn2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default: default:
} }
} }
@@ -173,7 +190,7 @@ func (l *Logger) Error2(format string, arg1, arg2 any) {
func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) { func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelWarn) { if l.level.Load() >= uint32(LevelWarn) {
select { select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default: default:
} }
} }
@@ -182,7 +199,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) { func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelWarn) { if l.level.Load() >= uint32(LevelWarn) {
select { select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default: default:
} }
} }
@@ -191,7 +208,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
func (l *Logger) Debug1(format string, arg1 any) { func (l *Logger) Debug1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelDebug) { if l.level.Load() >= uint32(LevelDebug) {
select { select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}: case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}:
default: default:
} }
} }
@@ -200,7 +217,7 @@ func (l *Logger) Debug1(format string, arg1 any) {
func (l *Logger) Debug2(format string, arg1, arg2 any) { func (l *Logger) Debug2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelDebug) { if l.level.Load() >= uint32(LevelDebug) {
select { select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}: case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default: default:
} }
} }
@@ -209,16 +226,59 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) { func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelDebug) { if l.level.Load() >= uint32(LevelDebug) {
select { select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default: default:
} }
} }
} }
// Debugf is the variadic shape. Dispatches to Debug/Debug1/Debug2/Debug3
// to avoid allocating an args slice on the fast path when the arg count is
// known (0-3). Args beyond 3 land on the general variadic path; callers on
// the hot path should prefer DebugN for known counts.
func (l *Logger) Debugf(format string, args ...any) {
if l.level.Load() < uint32(LevelDebug) {
return
}
switch len(args) {
case 0:
l.Debug(format)
case 1:
l.Debug1(format, args[0])
case 2:
l.Debug2(format, args[0], args[1])
case 3:
l.Debug3(format, args[0], args[1], args[2])
default:
l.sendVariadic(LevelDebug, format, args)
}
}
// sendVariadic packs a slice of arguments into a logMessage and non-blocking
// enqueues it. Used for arg counts beyond the fixed-arity fast paths. Args
// beyond the 8-arg slot limit are dropped so callers don't produce silently
// empty log lines via uint8 wraparound in argCount.
func (l *Logger) sendVariadic(level Level, format string, args []any) {
const maxArgs = 8
n := len(args)
if n > maxArgs {
n = maxArgs
}
msg := logMessage{level: level, argCount: uint8(n), format: format}
slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8}
for i := 0; i < n; i++ {
*slots[i] = args[i]
}
select {
case l.msgChannel <- msg:
default:
}
}
func (l *Logger) Trace1(format string, arg1 any) { func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
select { select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}: case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}:
default: default:
} }
} }
@@ -227,7 +287,7 @@ func (l *Logger) Trace1(format string, arg1 any) {
func (l *Logger) Trace2(format string, arg1, arg2 any) { func (l *Logger) Trace2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
select { select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}: case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default: default:
} }
} }
@@ -236,7 +296,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) {
func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) { func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
select { select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default: default:
} }
} }
@@ -245,7 +305,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) { func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
select { select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default: default:
} }
} }
@@ -254,7 +314,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) { func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
select { select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}: case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
default: default:
} }
} }
@@ -263,7 +323,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
select { select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}: case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
default: default:
} }
} }
@@ -273,7 +333,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) { func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
select { select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
default: default:
} }
} }
@@ -286,35 +346,8 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = append(*buf, levelStrings[msg.level]...) *buf = append(*buf, levelStrings[msg.level]...)
*buf = append(*buf, ' ') *buf = append(*buf, ' ')
// Count non-nil arguments for switch
argCount := 0
if msg.arg1 != nil {
argCount++
if msg.arg2 != nil {
argCount++
if msg.arg3 != nil {
argCount++
if msg.arg4 != nil {
argCount++
if msg.arg5 != nil {
argCount++
if msg.arg6 != nil {
argCount++
if msg.arg7 != nil {
argCount++
if msg.arg8 != nil {
argCount++
}
}
}
}
}
}
}
}
var formatted string var formatted string
switch argCount { switch msg.argCount {
case 0: case 0:
formatted = msg.format formatted = msg.format
case 1: case 1:

View File

@@ -11,10 +11,9 @@ import (
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
var ( var (
errInvalidIPHeaderLength = errors.New("invalid IP header length") errInvalidIPHeaderLength = errors.New("invalid IP header length")
) )
@@ -25,10 +24,33 @@ const (
destinationPortOffset = 2 destinationPortOffset = 2
// IP address offsets in IPv4 header // IP address offsets in IPv4 header
sourceIPOffset = 12 ipv4SrcOffset = 12
destinationIPOffset = 16 ipv4DstOffset = 16
// IP address offsets in IPv6 header
ipv6SrcOffset = 8
ipv6DstOffset = 24
// IPv6 fixed header length
ipv6HeaderLen = 40
) )
// ipHeaderLen returns the IP header length based on the decoded layer type.
func ipHeaderLen(d *decoder) (int, error) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
n := int(d.ip4.IHL) * 4
if n < 20 {
return 0, errInvalidIPHeaderLength
}
return n, nil
case layers.LayerTypeIPv6:
return ipv6HeaderLen, nil
default:
return 0, fmt.Errorf("unknown IP layer: %v", d.decoded[0])
}
}
// ipv4Checksum calculates IPv4 header checksum. // ipv4Checksum calculates IPv4 header checksum.
func ipv4Checksum(header []byte) uint16 { func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 { if len(header) < 20 {
@@ -234,19 +256,22 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false return false
} }
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) _, dstIP := extractPacketIPs(packetData, d)
translatedIP, exists := m.getDNATTranslation(dstIP) translatedIP, exists := m.getDNATTranslation(dstIP)
if !exists { if !exists {
return false return false
} }
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
m.logger.Error1("failed to rewrite packet destination: %v", err) if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite packet destination: %v", err)
}
return false return false
} }
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP) if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
}
return true return true
} }
@@ -256,54 +281,115 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false return false
} }
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) srcIP, _ := extractPacketIPs(packetData, d)
originalIP, exists := m.findReverseDNATMapping(srcIP) originalIP, exists := m.findReverseDNATMapping(srcIP)
if !exists { if !exists {
return false return false
} }
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
m.logger.Error1("failed to rewrite packet source: %v", err) if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite packet source: %v", err)
}
return false return false
} }
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP) if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
}
return true return true
} }
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. // extractPacketIPs extracts src and dst IP addresses directly from raw packet bytes.
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { func extractPacketIPs(packetData []byte, d *decoder) (src, dst netip.Addr) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
src = netip.AddrFrom4([4]byte{packetData[ipv4SrcOffset], packetData[ipv4SrcOffset+1], packetData[ipv4SrcOffset+2], packetData[ipv4SrcOffset+3]})
dst = netip.AddrFrom4([4]byte{packetData[ipv4DstOffset], packetData[ipv4DstOffset+1], packetData[ipv4DstOffset+2], packetData[ipv4DstOffset+3]})
case layers.LayerTypeIPv6:
src = netip.AddrFrom16([16]byte(packetData[ipv6SrcOffset : ipv6SrcOffset+16]))
dst = netip.AddrFrom16([16]byte(packetData[ipv6DstOffset : ipv6DstOffset+16]))
}
return src, dst
}
// rewritePacketIP replaces a source (isSource=true) or destination IP address in the packet and updates checksums.
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, isSource bool) error {
hdrLen, err := ipHeaderLen(d)
if err != nil {
return err
}
switch d.decoded[0] {
case layers.LayerTypeIPv4:
return m.rewriteIPv4(packetData, d, newIP, hdrLen, isSource)
case layers.LayerTypeIPv6:
return m.rewriteIPv6(packetData, d, newIP, hdrLen, isSource)
default:
return fmt.Errorf("unknown IP layer: %v", d.decoded[0])
}
}
func (m *Manager) rewriteIPv4(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
if !newIP.Is4() { if !newIP.Is4() {
return ErrIPv4Only return fmt.Errorf("cannot write IPv6 address into IPv4 packet")
}
offset := ipv4DstOffset
if isSource {
offset = ipv4SrcOffset
} }
var oldIP [4]byte var oldIP [4]byte
copy(oldIP[:], packetData[ipOffset:ipOffset+4]) copy(oldIP[:], packetData[offset:offset+4])
newIPBytes := newIP.As4() newIPBytes := newIP.As4()
copy(packetData[offset:offset+4], newIPBytes[:])
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) // Recalculate IPv4 header checksum
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
binary.BigEndian.PutUint16(packetData[10:12], 0) binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) binary.BigEndian.PutUint16(packetData[10:12], ipv4Checksum(packetData[:hdrLen]))
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
// Update transport checksums incrementally
if len(d.decoded) > 1 { if len(d.decoded) > 1 {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen) m.updateICMPChecksum(packetData, hdrLen)
} }
} }
return nil
}
func (m *Manager) rewriteIPv6(packetData []byte, d *decoder, newIP netip.Addr, hdrLen int, isSource bool) error {
if !newIP.Is6() {
return fmt.Errorf("cannot write IPv4 address into IPv6 packet")
}
offset := ipv6DstOffset
if isSource {
offset = ipv6SrcOffset
}
var oldIP [16]byte
copy(oldIP[:], packetData[offset:offset+16])
newIPBytes := newIP.As16()
copy(packetData[offset:offset+16], newIPBytes[:])
// IPv6 has no header checksum, only update transport checksums
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv6:
// ICMPv6 checksum includes pseudo-header with addresses, use incremental update
m.updateICMPv6Checksum(packetData, hdrLen, oldIP[:], newIPBytes[:])
}
}
return nil return nil
} }
@@ -351,6 +437,20 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum) binary.BigEndian.PutUint16(icmpData[2:4], checksum)
} }
// updateICMPv6Checksum updates ICMPv6 checksum after address change.
// ICMPv6 uses a pseudo-header (like TCP/UDP), so incremental update applies.
func (m *Manager) updateICMPv6Checksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+4 {
return
}
checksumOffset := icmpStart + 2
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624. // incrementalUpdate performs incremental checksum update per RFC 1624.
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum) sum := uint32(^oldChecksum)
@@ -403,14 +503,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
} }
// addPortRedirection adds a port redirection rule. // addPortRedirection adds a port redirection rule.
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, originalPort, translatedPort uint16) error {
m.portDNATMutex.Lock() m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock() defer m.portDNATMutex.Unlock()
rule := portDNATRule{ rule := portDNATRule{
protocol: protocol, protocol: protocol,
origPort: sourcePort, origPort: originalPort,
targetPort: targetPort, targetPort: translatedPort,
targetIP: targetIP, targetIP: targetIP,
} }
@@ -422,7 +522,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
// TODO: also delegate to nativeFirewall when available for kernel WG mode // TODO: also delegate to nativeFirewall when available for kernel WG mode
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
var layerType gopacket.LayerType var layerType gopacket.LayerType
switch protocol { switch protocol {
case firewall.ProtocolTCP: case firewall.ProtocolTCP:
@@ -433,16 +533,16 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
return fmt.Errorf("unsupported protocol: %s", protocol) return fmt.Errorf("unsupported protocol: %s", protocol)
} }
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort) return m.addPortRedirection(localAddr, layerType, originalPort, translatedPort)
} }
// removePortRedirection removes a port redirection rule. // removePortRedirection removes a port redirection rule.
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, originalPort, translatedPort uint16) error {
m.portDNATMutex.Lock() m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock() defer m.portDNATMutex.Unlock()
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool { m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0 return rule.protocol == protocol && rule.origPort == originalPort && rule.targetPort == translatedPort && rule.targetIP.Compare(targetIP) == 0
}) })
if len(m.portDNATRules) == 0 { if len(m.portDNATRules) == 0 {
@@ -453,7 +553,7 @@ func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.L
} }
// RemoveInboundDNAT removes an inbound DNAT rule. // RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
var layerType gopacket.LayerType var layerType gopacket.LayerType
switch protocol { switch protocol {
case firewall.ProtocolTCP: case firewall.ProtocolTCP:
@@ -464,23 +564,23 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return fmt.Errorf("unsupported protocol: %s", protocol) return fmt.Errorf("unsupported protocol: %s", protocol)
} }
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) return m.removePortRedirection(localAddr, layerType, originalPort, translatedPort)
} }
// AddOutputDNAT delegates to the native firewall if available. // AddOutputDNAT delegates to the native firewall if available.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return fmt.Errorf("output DNAT not supported without native firewall") return fmt.Errorf("output DNAT not supported without native firewall")
} }
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort) return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
} }
// RemoveOutputDNAT delegates to the native firewall if available. // RemoveOutputDNAT delegates to the native firewall if available.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return nil return nil
} }
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort) return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
} }
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. // translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
@@ -521,7 +621,9 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
} }
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil { if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
m.logger.Error1("failed to rewrite port: %v", err) if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite port: %v", err)
}
return false return false
} }
d.dnatOrigPort = rule.origPort d.dnatOrigPort = rule.origPort
@@ -532,12 +634,12 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. // rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4 hdrLen, err := ipHeaderLen(d)
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { if err != nil {
return errInvalidIPHeaderLength return err
} }
tcpStart := ipHeaderLen tcpStart := hdrLen
if len(packetData) < tcpStart+4 { if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header") return fmt.Errorf("packet too short for TCP header")
} }
@@ -563,12 +665,12 @@ func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16,
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. // rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4 hdrLen, err := ipHeaderLen(d)
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { if err != nil {
return errInvalidIPHeaderLength return err
} }
udpStart := ipHeaderLen udpStart := hdrLen
if len(packetData) < udpStart+8 { if len(packetData) < udpStart+8 {
return fmt.Errorf("packet too short for UDP header") return fmt.Errorf("packet too short for UDP header")
} }

View File

@@ -342,12 +342,17 @@ func BenchmarkDNATMemoryAllocations(b *testing.B) {
// Parse the packet fresh each time to get a clean decoder // Parse the packet fresh each time to get a clean decoder
d := &decoder{decoded: []gopacket.LayerType{}} d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
err = d.parser.DecodeLayers(testPacket, &d.decoded) d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err = d.decodePacket(testPacket)
assert.NoError(b, err) assert.NoError(b, err)
manager.translateOutboundDNAT(testPacket, d) manager.translateOutboundDNAT(testPacket, d)
@@ -371,12 +376,17 @@ func BenchmarkDirectIPExtraction(b *testing.B) {
b.Run("decoder_extraction", func(b *testing.B) { b.Run("decoder_extraction", func(b *testing.B) {
// Create decoder once for comparison // Create decoder once for comparison
d := &decoder{decoded: []gopacket.LayerType{}} d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packet, &d.decoded) d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err := d.decodePacket(packet)
assert.NoError(b, err) assert.NoError(b, err)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

View File

@@ -86,13 +86,18 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
d := &decoder{ d := &decoder{
decoded: []gopacket.LayerType{}, decoded: []gopacket.LayerType{},
} }
d.parser = gopacket.NewDecodingLayerParser( d.parser4 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4, layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
) )
d.parser.IgnoreUnsupported = true d.parser4.IgnoreUnsupported = true
d.parser6 = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv6,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser6.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packetData, &d.decoded) err := d.decodePacket(packetData)
require.NoError(t, err) require.NoError(t, err)
return d return d
} }

View File

@@ -2,7 +2,9 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"strconv"
"time" "time"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -112,10 +114,13 @@ func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string,
} }
func (p *PacketBuilder) Build() ([]byte, error) { func (p *PacketBuilder) Build() ([]byte, error) {
ip := p.buildIPLayer() ipLayer, err := p.buildIPLayer()
pktLayers := []gopacket.SerializableLayer{ip} if err != nil {
return nil, err
}
pktLayers := []gopacket.SerializableLayer{ipLayer}
transportLayer, err := p.buildTransportLayer(ip) transportLayer, err := p.buildTransportLayer(ipLayer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -129,30 +134,43 @@ func (p *PacketBuilder) Build() ([]byte, error) {
return serializePacket(pktLayers) return serializePacket(pktLayers)
} }
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 { func (p *PacketBuilder) buildIPLayer() (gopacket.SerializableLayer, error) {
if p.SrcIP.Is4() != p.DstIP.Is4() {
return nil, fmt.Errorf("mixed address families: src=%s dst=%s", p.SrcIP, p.DstIP)
}
proto := getIPProtocolNumber(p.Protocol, p.SrcIP.Is6())
if p.SrcIP.Is6() {
return &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: proto,
SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(),
}, nil
}
return &layers.IPv4{ return &layers.IPv4{
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), Protocol: proto,
SrcIP: p.SrcIP.AsSlice(), SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(), DstIP: p.DstIP.AsSlice(),
} }, nil
} }
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { func (p *PacketBuilder) buildTransportLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
switch p.Protocol { switch p.Protocol {
case "tcp": case "tcp":
return p.buildTCPLayer(ip) return p.buildTCPLayer(ipLayer)
case "udp": case "udp":
return p.buildUDPLayer(ip) return p.buildUDPLayer(ipLayer)
case "icmp": case "icmp":
return p.buildICMPLayer() return p.buildICMPLayer(ipLayer)
default: default:
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol) return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
} }
} }
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { func (p *PacketBuilder) buildTCPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
tcp := &layers.TCP{ tcp := &layers.TCP{
SrcPort: layers.TCPPort(p.SrcPort), SrcPort: layers.TCPPort(p.SrcPort),
DstPort: layers.TCPPort(p.DstPort), DstPort: layers.TCPPort(p.DstPort),
@@ -164,24 +182,44 @@ func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableL
PSH: p.TCPState != nil && p.TCPState.PSH, PSH: p.TCPState != nil && p.TCPState.PSH,
URG: p.TCPState != nil && p.TCPState.URG, URG: p.TCPState != nil && p.TCPState.URG,
} }
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil { if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err) if err := tcp.SetNetworkLayerForChecksum(nl); err != nil {
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
}
} }
return []gopacket.SerializableLayer{tcp}, nil return []gopacket.SerializableLayer{tcp}, nil
} }
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { func (p *PacketBuilder) buildUDPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
udp := &layers.UDP{ udp := &layers.UDP{
SrcPort: layers.UDPPort(p.SrcPort), SrcPort: layers.UDPPort(p.SrcPort),
DstPort: layers.UDPPort(p.DstPort), DstPort: layers.UDPPort(p.DstPort),
} }
if err := udp.SetNetworkLayerForChecksum(ip); err != nil { if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err) if err := udp.SetNetworkLayerForChecksum(nl); err != nil {
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
}
} }
return []gopacket.SerializableLayer{udp}, nil return []gopacket.SerializableLayer{udp}, nil
} }
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) { func (p *PacketBuilder) buildICMPLayer(ipLayer gopacket.SerializableLayer) ([]gopacket.SerializableLayer, error) {
if p.SrcIP.Is6() || p.DstIP.Is6() {
icmp := &layers.ICMPv6{
TypeCode: layers.CreateICMPv6TypeCode(p.ICMPType, p.ICMPCode),
}
if nl, ok := ipLayer.(gopacket.NetworkLayer); ok {
_ = icmp.SetNetworkLayerForChecksum(nl)
}
if p.ICMPType == layers.ICMPv6TypeEchoRequest || p.ICMPType == layers.ICMPv6TypeEchoReply {
echo := &layers.ICMPv6Echo{
Identifier: 1,
SeqNumber: 1,
}
return []gopacket.SerializableLayer{icmp, echo}, nil
}
return []gopacket.SerializableLayer{icmp}, nil
}
icmp := &layers.ICMPv4{ icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode), TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
} }
@@ -204,14 +242,17 @@ func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func getIPProtocolNumber(protocol fw.Protocol) int { func getIPProtocolNumber(protocol fw.Protocol, isV6 bool) layers.IPProtocol {
switch protocol { switch protocol {
case fw.ProtocolTCP: case fw.ProtocolTCP:
return int(layers.IPProtocolTCP) return layers.IPProtocolTCP
case fw.ProtocolUDP: case fw.ProtocolUDP:
return int(layers.IPProtocolUDP) return layers.IPProtocolUDP
case fw.ProtocolICMP: case fw.ProtocolICMP:
return int(layers.IPProtocolICMPv4) if isV6 {
return layers.IPProtocolICMPv6
}
return layers.IPProtocolICMPv4
default: default:
return 0 return 0
} }
@@ -234,7 +275,7 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
trace := &PacketTrace{Direction: direction} trace := &PacketTrace{Direction: direction}
// Initial packet decoding // Initial packet decoding
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false) trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
return trace return trace
} }
@@ -256,6 +297,8 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
trace.DestinationPort = uint16(d.udp.DstPort) trace.DestinationPort = uint16(d.udp.DstPort)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
trace.Protocol = "ICMP" trace.Protocol = "ICMP"
case layers.LayerTypeICMPv6:
trace.Protocol = "ICMPv6"
} }
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d", trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
@@ -319,6 +362,13 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
flags&conntrack.TCPFin != 0) flags&conntrack.TCPFin != 0)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq) msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
case layers.LayerTypeICMPv6:
var id, seq uint16
if len(d.icmp6.Payload) >= 4 {
id = uint16(d.icmp6.Payload[0])<<8 | uint16(d.icmp6.Payload[1])
seq = uint16(d.icmp6.Payload[2])<<8 | uint16(d.icmp6.Payload[3])
}
msg += fmt.Sprintf(" (ICMPv6 ID=%d, Seq=%d)", id, seq)
} }
return msg return msg
} }
@@ -395,7 +445,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
trace.AddResult(StageRouteACL, msg, allowed) trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder.Load() != nil { if allowed && m.forwarder.Load() != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true) m.addForwardingResult(trace, "proxy-remote", net.JoinHostPort(dstIP.String(), strconv.Itoa(int(dstPort))), true)
} }
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed) trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
@@ -415,7 +465,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
trace.AddResult(StageCompleted, "Packet dropped - decode error", false) trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
return trace return trace
} }
@@ -434,7 +484,7 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
if portDNATApplied { if portDNATApplied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
return true return true
} }
@@ -444,7 +494,7 @@ func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *de
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
if nat1to1Applied { if nat1to1Applied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.decodePacket(packetData); err != nil {
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
return true return true
} }
@@ -509,7 +559,7 @@ func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *
return false return false
} }
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) srcIP, _ := extractPacketIPs(packetData, d)
translated := m.translateInboundReverse(packetData, d) translated := m.translateInboundReverse(packetData, d)
if translated { if translated {
@@ -539,7 +589,7 @@ func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d
return false return false
} }
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) _, dstIP := extractPacketIPs(packetData, d)
translated := m.translateOutboundDNAT(packetData, d) translated := m.translateOutboundDNAT(packetData, d)
if translated { if translated {

View File

@@ -119,7 +119,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
if err != nil { if err != nil {
return fmt.Errorf("failed to parse endpoint address: %w", err) return fmt.Errorf("failed to parse endpoint address: %w", err)
} }
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port)) addrPort := netip.AddrPortFrom(addr.Unmap(), uint16(endpoint.Port))
c.activityRecorder.UpsertAddress(peerKey, addrPort) c.activityRecorder.UpsertAddress(peerKey, addrPort)
} }
return nil return nil

View File

@@ -2,7 +2,7 @@ package device
// TunAdapter is an interface for create tun device from external service // TunAdapter is an interface for create tun device from external service
type TunAdapter interface { type TunAdapter interface {
ConfigureInterface(address string, mtu int, dns string, searchDomains string, routes string) (int, error) ConfigureInterface(address string, addressV6 string, mtu int, dns string, searchDomains string, routes string) (int, error)
UpdateAddr(address string) error UpdateAddr(address string) error
ProtectSocket(fd int32) bool ProtectSocket(fd int32) bool
} }

View File

@@ -63,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
searchDomainsToString = "" searchDomainsToString = ""
} }
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.address.IPv6String(), int(t.mtu), dns, searchDomainsToString, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)
return nil, err return nil, err

View File

@@ -131,23 +131,32 @@ func (t *TunDevice) Device() *device.Device {
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided // assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error { func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) if out, err := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()).CombinedOutput(); err != nil {
if out, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("add v4 address: %s: %w", string(out), err)
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)
return err
} }
// dummy ipv6 so routing works // Assign a dummy link-local so macOS enables IPv6 on the tun device.
cmd = exec.Command("ifconfig", t.name, "inet6", "fe80::/64") // When a real overlay v6 is present, use that instead.
if out, err := cmd.CombinedOutput(); err != nil { v6Addr := "fe80::/64"
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out) if t.address.HasIPv6() {
v6Addr = t.address.IPv6String()
}
if out, err := exec.Command("ifconfig", t.name, "inet6", v6Addr).CombinedOutput(); err != nil {
log.Warnf("failed to assign IPv6 address %s, continuing v4-only: %s: %v", v6Addr, string(out), err)
t.address.ClearIPv6()
} }
routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name) if out, err := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name).CombinedOutput(); err != nil {
if out, err := routeCmd.CombinedOutput(); err != nil { return fmt.Errorf("add route %s via %s: %s: %w", t.address.Network, t.name, string(out), err)
log.Errorf("adding route command '%v' failed with output: %s", routeCmd.String(), out)
return err
} }
if t.address.HasIPv6() {
if out, err := exec.Command("route", "add", "-inet6", "-net", t.address.IPv6Net.String(), "-interface", t.name).CombinedOutput(); err != nil {
log.Warnf("failed to add route %s via %s, continuing v4-only: %s: %v", t.address.IPv6Net, t.name, string(out), err)
t.address.ClearIPv6()
}
}
return nil return nil
} }

View File

@@ -151,8 +151,11 @@ func (t *TunDevice) MTU() uint16 {
return t.mtu return t.mtu
} }
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error { // UpdateAddr updates the device address. On iOS the tunnel is managed by the
// todo implement // NetworkExtension, so we only store the new value. The extension picks up the
// change on the next tunnel reconfiguration.
func (t *TunDevice) UpdateAddr(addr wgaddr.Address) error {
t.address = addr
return nil return nil
} }

View File

@@ -173,7 +173,7 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
// assignAddr Adds IP address to the tunnel interface // assignAddr Adds IP address to the tunnel interface
func (t *TunKernelDevice) assignAddr() error { func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(t.address) return t.link.assignAddr(&t.address)
} }
func (t *TunKernelDevice) GetNet() *netstack.Net { func (t *TunKernelDevice) GetNet() *netstack.Net {

View File

@@ -3,6 +3,7 @@ package device
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/netip"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
@@ -63,8 +64,12 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
return nil, fmt.Errorf("last ip: %w", err) return nil, fmt.Errorf("last ip: %w", err)
} }
log.Debugf("netstack using address: %s", t.address.IP) addresses := []netip.Addr{t.address.IP}
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu)) if t.address.HasIPv6() {
addresses = append(addresses, t.address.IPv6)
}
log.Debugf("netstack using addresses: %v", addresses)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, addresses, dnsAddr, int(t.mtu))
log.Debugf("netstack using dns address: %s", dnsAddr) log.Debugf("netstack using dns address: %s", dnsAddr)
tunIface, net, err := t.nsTun.Create() tunIface, net, err := t.nsTun.Create()
if err != nil { if err != nil {

View File

@@ -16,7 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type USPDevice struct { type TunDevice struct {
name string name string
address wgaddr.Address address wgaddr.Address
port int port int
@@ -30,10 +30,10 @@ type USPDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
log.Infof("using userspace bind mode") log.Infof("using userspace bind mode")
return &USPDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
@@ -43,7 +43,7 @@ func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu
} }
} }
func (t *USPDevice) Create() (WGConfigurer, error) { func (t *TunDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface") log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, int(t.mtu)) tunIface, err := tun.CreateTUN(t.name, int(t.mtu))
if err != nil { if err != nil {
@@ -75,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil { if t.device == nil {
return nil, fmt.Errorf("device is not ready yet") return nil, fmt.Errorf("device is not ready yet")
} }
@@ -95,12 +95,12 @@ func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error { func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
func (t *USPDevice) Close() error { func (t *TunDevice) Close() error {
if t.configurer != nil { if t.configurer != nil {
t.configurer.Close() t.configurer.Close()
} }
@@ -115,39 +115,39 @@ func (t *USPDevice) Close() error {
return nil return nil
} }
func (t *USPDevice) WgAddress() wgaddr.Address { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *USPDevice) MTU() uint16 { func (t *TunDevice) MTU() uint16 {
return t.mtu return t.mtu
} }
func (t *USPDevice) DeviceName() string { func (t *TunDevice) DeviceName() string {
return t.name return t.name
} }
func (t *USPDevice) FilteredDevice() *FilteredDevice { func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device // Device returns the wireguard device
func (t *USPDevice) Device() *device.Device { func (t *TunDevice) Device() *device.Device {
return t.device return t.device
} }
// assignAddr Adds IP address to the tunnel interface // assignAddr Adds IP address to the tunnel interface
func (t *USPDevice) assignAddr() error { func (t *TunDevice) assignAddr() error {
link := newWGLink(t.name) link := newWGLink(t.name)
return link.assignAddr(t.address) return link.assignAddr(&t.address)
} }
func (t *USPDevice) GetNet() *netstack.Net { func (t *TunDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns the ICEBind instance // GetICEBind returns the ICEBind instance
func (t *USPDevice) GetICEBind() EndpointManager { func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind return t.iceBind
} }

View File

@@ -87,7 +87,21 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
err = nbiface.Set() err = nbiface.Set()
if err != nil { if err != nil {
t.device.Close() t.device.Close()
return nil, fmt.Errorf("got error when getting setting the interface mtu: %s", err) return nil, fmt.Errorf("set IPv4 interface MTU: %s", err)
}
if t.address.HasIPv6() {
nbiface6, err := luid.IPInterface(windows.AF_INET6)
if err != nil {
log.Warnf("failed to get IPv6 interface for MTU, continuing v4-only: %v", err)
t.address.ClearIPv6()
} else {
nbiface6.NLMTU = uint32(t.mtu)
if err := nbiface6.Set(); err != nil {
log.Warnf("failed to set IPv6 interface MTU, continuing v4-only: %v", err)
t.address.ClearIPv6()
}
}
} }
err = t.assignAddr() err = t.assignAddr()
if err != nil { if err != nil {
@@ -178,8 +192,21 @@ func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided // assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error { func (t *TunDevice) assignAddr() error {
luid := winipcfg.LUID(t.nativeTunDevice.LUID()) luid := winipcfg.LUID(t.nativeTunDevice.LUID())
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())}) v4Prefix := t.address.Prefix()
if t.address.HasIPv6() {
v6Prefix := t.address.IPv6Prefix()
log.Debugf("adding addresses %s, %s to interface: %s", v4Prefix, v6Prefix, t.name)
if err := luid.SetIPAddresses([]netip.Prefix{v4Prefix, v6Prefix}); err != nil {
log.Warnf("failed to assign dual-stack addresses, retrying v4-only: %v", err)
t.address.ClearIPv6()
return luid.SetIPAddresses([]netip.Prefix{v4Prefix})
}
return nil
}
log.Debugf("adding address %s to interface: %s", v4Prefix, t.name)
return luid.SetIPAddresses([]netip.Prefix{v4Prefix})
} }
func (t *TunDevice) GetNet() *netstack.Net { func (t *TunDevice) GetNet() *netstack.Net {

View File

@@ -1,8 +0,0 @@
//go:build (!linux && !freebsd) || android
package device
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
return false
}

View File

@@ -1,18 +0,0 @@
package device
// WireGuardModuleIsLoaded check if kernel support wireguard
func WireGuardModuleIsLoaded() bool {
// Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd)
// we are currently do not use it, since it is required to add wireguard kernel support to
// - https://github.com/netbirdio/netbird/tree/main/sharedsock
// - https://github.com/mdlayher/socket
// TODO: implement kernel space
return false
}
// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
func ModuleTunIsLoaded() bool {
// Assume tun supported by freebsd kernel by default
// TODO: implement check for module loaded in kernel or build-it
return true
}

View File

@@ -0,0 +1,13 @@
//go:build !linux || android
package device
// WireGuardModuleIsLoaded reports whether the kernel WireGuard module is available.
func WireGuardModuleIsLoaded() bool {
return false
}
// ModuleTunIsLoaded reports whether the tun device is available.
func ModuleTunIsLoaded() bool {
return true
}

View File

@@ -2,6 +2,7 @@ package device
import ( import (
"fmt" "fmt"
"os/exec"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -57,32 +58,32 @@ func (l *wgLink) up() error {
return nil return nil
} }
func (l *wgLink) assignAddr(address wgaddr.Address) error { func (l *wgLink) assignAddr(address *wgaddr.Address) error {
link, err := freebsd.LinkByName(l.name) link, err := freebsd.LinkByName(l.name)
if err != nil { if err != nil {
return fmt.Errorf("link by name: %w", err) return fmt.Errorf("link by name: %w", err)
} }
ip := address.IP.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits() prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen) maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits) mask := fmt.Sprintf("0x%08x", maskBits)
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name) log.Infof("assign addr %s mask %s to %s interface", address.IP, mask, l.name)
err = link.AssignAddr(ip, mask) if err := link.AssignAddr(address.IP.String(), mask); err != nil {
if err != nil {
return fmt.Errorf("assign addr: %w", err) return fmt.Errorf("assign addr: %w", err)
} }
err = link.Up() if address.HasIPv6() {
if err != nil { log.Infof("assign IPv6 addr %s to %s interface", address.IPv6String(), l.name)
cmd := exec.Command("ifconfig", l.name, "inet6", address.IPv6String())
if out, err := cmd.CombinedOutput(); err != nil {
log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %s: %v", address.IPv6String(), l.name, string(out), err)
address.ClearIPv6()
}
}
if err := link.Up(); err != nil {
return fmt.Errorf("up: %w", err) return fmt.Errorf("up: %w", err)
} }

View File

@@ -4,6 +4,8 @@ package device
import ( import (
"fmt" "fmt"
"net"
"net/netip"
"os" "os"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -92,7 +94,7 @@ func (l *wgLink) up() error {
return nil return nil
} }
func (l *wgLink) assignAddr(address wgaddr.Address) error { func (l *wgLink) assignAddr(address *wgaddr.Address) error {
//delete existing addresses //delete existing addresses
list, err := netlink.AddrList(l, 0) list, err := netlink.AddrList(l, 0)
if err != nil { if err != nil {
@@ -110,20 +112,16 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
} }
name := l.attrs.Name name := l.attrs.Name
addrStr := address.String()
log.Debugf("adding address %s to interface: %s", addrStr, name) if err := l.addAddr(name, address.Prefix()); err != nil {
return err
addr, err := netlink.ParseAddr(addrStr)
if err != nil {
return fmt.Errorf("parse addr: %w", err)
} }
err = netlink.AddrAdd(l, addr) if address.HasIPv6() {
if os.IsExist(err) { if err := l.addAddr(name, address.IPv6Prefix()); err != nil {
log.Infof("interface %s already has the address: %s", name, addrStr) log.Warnf("failed to assign IPv6 address %s to %s, continuing v4-only: %v", address.IPv6Prefix(), name, err)
} else if err != nil { address.ClearIPv6()
return fmt.Errorf("add addr: %w", err) }
} }
// On linux, the link must be brought up // On linux, the link must be brought up
@@ -133,3 +131,22 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
return nil return nil
} }
func (l *wgLink) addAddr(ifaceName string, prefix netip.Prefix) error {
log.Debugf("adding address %s to interface: %s", prefix, ifaceName)
addr := &netlink.Addr{
IPNet: &net.IPNet{
IP: prefix.Addr().AsSlice(),
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()),
},
}
if err := netlink.AddrAdd(l, addr); os.IsExist(err) {
log.Infof("interface %s already has the address: %s", ifaceName, prefix)
} else if err != nil {
return fmt.Errorf("add addr %s: %w", prefix, err)
}
return nil
}

View File

@@ -57,7 +57,7 @@ type wgProxyFactory interface {
type WGIFaceOpts struct { type WGIFaceOpts struct {
IFaceName string IFaceName string
Address string Address wgaddr.Address
WGPort int WGPort int
WGPrivKey string WGPrivKey string
MTU uint16 MTU uint16
@@ -141,16 +141,11 @@ func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
} }
// UpdateAddr updates address of the interface // UpdateAddr updates address of the interface
func (w *WGIface) UpdateAddr(newAddr string) error { func (w *WGIface) UpdateAddr(newAddr wgaddr.Address) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
addr, err := wgaddr.ParseWGAddress(newAddr) return w.tun.UpdateAddr(newAddr)
if err != nil {
return err
}
return w.tun.UpdateAddr(addr)
} }
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist

View File

@@ -1,33 +1,28 @@
//go:build !linux && !ios && !android && !js
package iface package iface
import ( import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) tun = device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else { } else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) tun = device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
} }
wgIFace := &WGIface{ return &WGIface{
userspaceBind: true, userspaceBind: true,
tun: tun, tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }, nil
return wgIFace, nil
} }

View File

@@ -4,23 +4,17 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
if netstack.IsEnabled() { if netstack.IsEnabled() {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }
return wgIFace, nil return wgIFace, nil
@@ -28,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), tun: device.NewTunDevice(opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }
return wgIFace, nil return wgIFace, nil

View File

@@ -1,35 +0,0 @@
//go:build !ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil
}

View File

@@ -1,41 +0,0 @@
//go:build freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -5,21 +5,15 @@ package iface
import ( import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace := &WGIface{ wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true, userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
} }

View File

@@ -4,21 +4,15 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode) // NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
relayBind := bind.NewRelayBindJS() relayBind := bind.NewRelayBindJS()
wgIface := &WGIface{ wgIface := &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()), tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
userspaceBind: true, userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU), wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
} }

View File

@@ -3,44 +3,40 @@
package iface package iface
import ( import (
"fmt" "errors"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() { if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) return &WGIface{
wgIFace.userspaceBind = true tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) userspaceBind: true,
return wgIFace, nil wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}, nil
} }
if device.WireGuardModuleIsLoaded() { if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet) return &WGIface{
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU) tun: device.NewKernelDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet),
return wgIFace, nil wgProxyFactory: wgproxy.NewKernelFactory(opts.WGPort, opts.MTU),
} }, nil
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
} }
return nil, fmt.Errorf("couldn't check or load tun module") if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
return &WGIface{
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}, nil
}
return nil, errors.New("tun module not available")
} }

View File

@@ -16,6 +16,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
) )
@@ -48,7 +49,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
opts := WGIFaceOpts{ opts := WGIFaceOpts{
IFaceName: ifaceName, IFaceName: ifaceName,
Address: addr, Address: wgaddr.MustParseWGAddress(addr),
WGPort: wgPort, WGPort: wgPort,
WGPrivKey: key, WGPrivKey: key,
MTU: DefaultMTU, MTU: DefaultMTU,
@@ -84,7 +85,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
//update WireGuard address //update WireGuard address
addr = "100.64.0.2/8" addr = "100.64.0.2/8"
err = iface.UpdateAddr(addr) err = iface.UpdateAddr(wgaddr.MustParseWGAddress(addr))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -130,7 +131,7 @@ func Test_CreateInterface(t *testing.T) {
} }
opts := WGIFaceOpts{ opts := WGIFaceOpts{
IFaceName: ifaceName, IFaceName: ifaceName,
Address: wgIP, Address: wgaddr.MustParseWGAddress(wgIP),
WGPort: 33100, WGPort: 33100,
WGPrivKey: key, WGPrivKey: key,
MTU: DefaultMTU, MTU: DefaultMTU,
@@ -174,7 +175,7 @@ func Test_Close(t *testing.T) {
opts := WGIFaceOpts{ opts := WGIFaceOpts{
IFaceName: ifaceName, IFaceName: ifaceName,
Address: wgIP, Address: wgaddr.MustParseWGAddress(wgIP),
WGPort: wgPort, WGPort: wgPort,
WGPrivKey: key, WGPrivKey: key,
MTU: DefaultMTU, MTU: DefaultMTU,
@@ -219,7 +220,7 @@ func TestRecreation(t *testing.T) {
opts := WGIFaceOpts{ opts := WGIFaceOpts{
IFaceName: ifaceName, IFaceName: ifaceName,
Address: wgIP, Address: wgaddr.MustParseWGAddress(wgIP),
WGPort: wgPort, WGPort: wgPort,
WGPrivKey: key, WGPrivKey: key,
MTU: DefaultMTU, MTU: DefaultMTU,
@@ -291,7 +292,7 @@ func Test_ConfigureInterface(t *testing.T) {
} }
opts := WGIFaceOpts{ opts := WGIFaceOpts{
IFaceName: ifaceName, IFaceName: ifaceName,
Address: wgIP, Address: wgaddr.MustParseWGAddress(wgIP),
WGPort: wgPort, WGPort: wgPort,
WGPrivKey: key, WGPrivKey: key,
MTU: DefaultMTU, MTU: DefaultMTU,
@@ -347,7 +348,7 @@ func Test_UpdatePeer(t *testing.T) {
opts := WGIFaceOpts{ opts := WGIFaceOpts{
IFaceName: ifaceName, IFaceName: ifaceName,
Address: wgIP, Address: wgaddr.MustParseWGAddress(wgIP),
WGPort: 33100, WGPort: 33100,
WGPrivKey: key, WGPrivKey: key,
MTU: DefaultMTU, MTU: DefaultMTU,
@@ -417,7 +418,7 @@ func Test_RemovePeer(t *testing.T) {
opts := WGIFaceOpts{ opts := WGIFaceOpts{
IFaceName: ifaceName, IFaceName: ifaceName,
Address: wgIP, Address: wgaddr.MustParseWGAddress(wgIP),
WGPort: 33100, WGPort: 33100,
WGPrivKey: key, WGPrivKey: key,
MTU: DefaultMTU, MTU: DefaultMTU,
@@ -482,7 +483,7 @@ func Test_ConnectPeers(t *testing.T) {
optsPeer1 := WGIFaceOpts{ optsPeer1 := WGIFaceOpts{
IFaceName: peer1ifaceName, IFaceName: peer1ifaceName,
Address: peer1wgIP.String(), Address: wgaddr.MustParseWGAddress(peer1wgIP.String()),
WGPort: peer1wgPort, WGPort: peer1wgPort,
WGPrivKey: peer1Key.String(), WGPrivKey: peer1Key.String(),
MTU: DefaultMTU, MTU: DefaultMTU,
@@ -522,7 +523,7 @@ func Test_ConnectPeers(t *testing.T) {
optsPeer2 := WGIFaceOpts{ optsPeer2 := WGIFaceOpts{
IFaceName: peer2ifaceName, IFaceName: peer2ifaceName,
Address: peer2wgIP.String(), Address: wgaddr.MustParseWGAddress(peer2wgIP.String()),
WGPort: peer2wgPort, WGPort: peer2wgPort,
WGPrivKey: peer2Key.String(), WGPrivKey: peer2Key.String(),
MTU: DefaultMTU, MTU: DefaultMTU,

View File

@@ -13,7 +13,7 @@ import (
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY" const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive type NetStackTun struct { //nolint:revive
address netip.Addr addresses []netip.Addr
dnsAddress netip.Addr dnsAddress netip.Addr
mtu int mtu int
listenAddress string listenAddress string
@@ -22,9 +22,9 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device tundev tun.Device
} }
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun { func NewNetStackTun(listenAddress string, addresses []netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
return &NetStackTun{ return &NetStackTun{
address: address, addresses: addresses,
dnsAddress: dnsAddress, dnsAddress: dnsAddress,
mtu: mtu, mtu: mtu,
listenAddress: listenAddress, listenAddress: listenAddress,
@@ -33,7 +33,7 @@ func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.A
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
nsTunDev, tunNet, err := netstack.CreateNetTUN( nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{t.address}, t.addresses,
[]netip.Addr{t.dnsAddress}, []netip.Addr{t.dnsAddress},
t.mtu) t.mtu)
if err != nil { if err != nil {

View File

@@ -3,12 +3,18 @@ package wgaddr
import ( import (
"fmt" "fmt"
"net/netip" "net/netip"
"github.com/netbirdio/netbird/shared/netiputil"
) )
// Address WireGuard parsed address // Address WireGuard parsed address
type Address struct { type Address struct {
IP netip.Addr IP netip.Addr
Network netip.Prefix Network netip.Prefix
// IPv6 overlay address, if assigned.
IPv6 netip.Addr
IPv6Net netip.Prefix
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
@@ -23,6 +29,60 @@ func ParseWGAddress(address string) (Address, error) {
}, nil }, nil
} }
func (addr Address) String() string { // HasIPv6 reports whether a v6 overlay address is assigned.
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits()) func (addr Address) HasIPv6() bool {
return addr.IPv6.IsValid()
}
func (addr Address) String() string {
return addr.Prefix().String()
}
// IPv6String returns the v6 address in CIDR notation, or empty string if none.
func (addr Address) IPv6String() string {
if !addr.HasIPv6() {
return ""
}
return addr.IPv6Prefix().String()
}
// Prefix returns the v4 host address with its network prefix length (e.g. 100.64.0.1/16).
func (addr Address) Prefix() netip.Prefix {
return netip.PrefixFrom(addr.IP, addr.Network.Bits())
}
// IPv6Prefix returns the v6 host address with its network prefix length, or a zero prefix if none.
func (addr Address) IPv6Prefix() netip.Prefix {
if !addr.HasIPv6() {
return netip.Prefix{}
}
return netip.PrefixFrom(addr.IPv6, addr.IPv6Net.Bits())
}
// SetIPv6FromCompact decodes a compact prefix (5 or 17 bytes) and sets the IPv6 fields.
// Returns an error if the bytes are invalid. A nil or empty input is a no-op.
//
//nolint:recvcheck
func (addr *Address) SetIPv6FromCompact(raw []byte) error {
if len(raw) == 0 {
return nil
}
prefix, err := netiputil.DecodePrefix(raw)
if err != nil {
return fmt.Errorf("decode v6 overlay address: %w", err)
}
if !prefix.Addr().Is6() {
return fmt.Errorf("expected IPv6 address, got %s", prefix.Addr())
}
addr.IPv6 = prefix.Addr()
addr.IPv6Net = prefix.Masked()
return nil
}
// ClearIPv6 removes the IPv6 overlay address, leaving only v4.
//
//nolint:recvcheck
func (addr *Address) ClearIPv6() {
addr.IPv6 = netip.Addr{}
addr.IPv6Net = netip.Prefix{}
} }

View File

@@ -0,0 +1,10 @@
package wgaddr
// MustParseWGAddress parses and returns a WG Address, panicking on error.
func MustParseWGAddress(address string) Address {
a, err := ParseWGAddress(address)
if err != nil {
panic(err)
}
return a
}

View File

@@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -196,18 +196,25 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
} }
} }
// fakeAddress returns a fake address that is used to as an identifier for the peer. // fakeAddress returns a fake address that is used as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address. // The fake address is in the format of 127.1.x.x where x.x is derived from the
// last two bytes of the peer address (works for both IPv4 and IPv6).
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) { func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
octets := strings.Split(peerAddress.IP.String(), ".") if peerAddress == nil {
if len(octets) != 4 { return nil, fmt.Errorf("nil peer address")
return nil, fmt.Errorf("invalid IP format") }
if peerAddress.Port < 0 || peerAddress.Port > 65535 {
return nil, fmt.Errorf("invalid UDP port: %d", peerAddress.Port)
} }
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])) addr, ok := netip.AddrFromSlice(peerAddress.IP)
if err != nil { if !ok {
return nil, fmt.Errorf("parse new IP: %w", err) return nil, fmt.Errorf("invalid IP format")
} }
addr = addr.Unmap()
raw := addr.As16()
fakeIP := netip.AddrFrom4([4]byte{127, 1, raw[14], raw[15]})
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil return &netipAddr, nil

View File

@@ -280,6 +280,43 @@ CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}" CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
SectionEnd SectionEnd
# Install the Microsoft Edge WebView2 runtime if it isn't already present.
# Macro adapted from Wails3's NSIS template (wails_tools.nsh): a registry
# probe followed by a silent install of the embedded evergreen bootstrapper.
# The MicrosoftEdgeWebview2Setup.exe payload is staged next to this script
# by the sign-pipelines build step (`wails3 generate webview2bootstrapper`).
!macro nb.webview2runtime
SetRegView 64
# Per-machine install marker — populated when the runtime ships with
# Edge or has been installed by an admin previously.
ReadRegStr $0 HKLM "SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" "pv"
${If} $0 != ""
Goto webview2_ok
${EndIf}
# Per-user fallback for HKCU installs.
ReadRegStr $0 HKCU "Software\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" "pv"
${If} $0 != ""
Goto webview2_ok
${EndIf}
SetDetailsPrint both
DetailPrint "Installing: WebView2 Runtime"
SetDetailsPrint listonly
InitPluginsDir
CreateDirectory "$pluginsdir\webview2bootstrapper"
SetOutPath "$pluginsdir\webview2bootstrapper"
File "MicrosoftEdgeWebview2Setup.exe"
ExecWait '"$pluginsdir\webview2bootstrapper\MicrosoftEdgeWebview2Setup.exe" /silent /install'
SetDetailsPrint both
webview2_ok:
!macroend
Section -WebView2
!insertmacro nb.webview2runtime
SectionEnd
Section -Post Section -Post
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service install' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service install'
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service start' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service start'
@@ -326,9 +363,9 @@ DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}" Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}" Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll" Delete "$INSTDIR\wintun.dll"
!if ${ARCH} == "amd64" # Legacy: pre-Wails installs shipped opengl32.dll (Mesa3D for Fyne); remove
# any leftover copy on uninstall so old upgrades don't leave it behind.
Delete "$INSTDIR\opengl32.dll" Delete "$INSTDIR\opengl32.dll"
!endif
DetailPrint "Removing application directory..." DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"

View File

@@ -5,7 +5,6 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"strconv" "strconv"
"sync" "sync"
@@ -19,6 +18,7 @@ import (
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
) )
var ErrSourceRangesEmpty = errors.New("sources range is empty") var ErrSourceRangesEmpty = errors.New("sources range is empty")
@@ -105,6 +105,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
newRulePairs := make(map[id.RuleID][]firewall.Rule) newRulePairs := make(map[id.RuleID][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string) ipsetByRuleSelectors := make(map[string]string)
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
// the missing deny. Currently we accumulate errors and continue.
var merr *multierror.Error
for _, r := range rules { for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet // if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it // it's IP address can be used in the ipset for firewall manager which supports it
@@ -117,9 +121,8 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
} }
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName) pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil { if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err) merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
d.rollBack(newRulePairs) continue
break
} }
if len(rulePair) > 0 { if len(rulePair) > 0 {
d.peerRulesPairs[pairID] = rulePair d.peerRulesPairs[pairID] = rulePair
@@ -127,6 +130,10 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
} }
} }
if merr != nil {
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
}
for pairID, rules := range d.peerRulesPairs { for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; !ok { if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules { for _, rule := range rules {
@@ -216,9 +223,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule, r *mgmProto.FirewallRule,
ipsetName string, ipsetName string,
) (id.RuleID, []firewall.Rule, error) { ) (id.RuleID, []firewall.Rule, error) {
ip := net.ParseIP(r.PeerIP) ip, err := extractRuleIP(r)
if ip == nil { if err != nil {
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") return "", nil, err
} }
protocol, err := convertToFirewallProtocol(r.Protocol) protocol, err := convertToFirewallProtocol(r.Protocol)
@@ -289,13 +296,13 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
func (d *DefaultManager) addInRules( func (d *DefaultManager) addInRules(
id []byte, id []byte,
ip net.IP, ip netip.Addr,
protocol firewall.Protocol, protocol firewall.Protocol,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName) rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
if err != nil { if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
@@ -305,7 +312,7 @@ func (d *DefaultManager) addInRules(
func (d *DefaultManager) addOutRules( func (d *DefaultManager) addOutRules(
id []byte, id []byte,
ip net.IP, ip netip.Addr,
protocol firewall.Protocol, protocol firewall.Protocol,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
@@ -315,7 +322,7 @@ func (d *DefaultManager) addOutRules(
return nil, nil return nil, nil
} }
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName) rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
if err != nil { if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
@@ -323,9 +330,9 @@ func (d *DefaultManager) addOutRules(
return rule, nil return rule, nil
} }
// getPeerRuleID() returns unique ID for the rule based on its parameters. // getPeerRuleID returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getPeerRuleID( func (d *DefaultManager) getPeerRuleID(
ip net.IP, ip netip.Addr,
proto firewall.Protocol, proto firewall.Protocol,
direction int, direction int,
port *firewall.Port, port *firewall.Port,
@@ -344,15 +351,25 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
} }
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
log.Debugf("rollback ACL to previous state") // extractRuleIP extracts the peer IP from a firewall rule.
for _, rules := range newRulePairs { // If sourcePrefixes is populated (new management), decode the first entry and use its address.
for _, rule := range rules { // Otherwise fall back to the deprecated PeerIP string field (old management).
if err := d.firewall.DeletePeerRule(rule); err != nil { func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err) if len(r.SourcePrefixes) > 0 {
} addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
if err != nil {
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
} }
return addr.Unmap(), nil
} }
//nolint:staticcheck // PeerIP used for backward compatibility with old management
addr, err := netip.ParseAddr(r.PeerIP)
if err != nil {
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
}
return addr.Unmap(), nil
} }
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) { func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {

View File

@@ -321,6 +321,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.DisableFirewall, a.config.DisableFirewall,
a.config.BlockLANAccess, a.config.BlockLANAccess,
a.config.BlockInbound, a.config.BlockInbound,
a.config.DisableIPv6,
a.config.LazyConnectionEnabled, a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot, a.config.EnableSSHRoot,
a.config.EnableSSHSFTP, a.config.EnableSSHSFTP,

Some files were not shown because too many files have changed in this diff Show More