Compare commits

..

267 Commits

Author SHA1 Message Date
Zoltán Papp
dfee5252a3 client/ui: open main window on tray left-click on Linux
KDE Plasma routes a tray left-click to the SNI Activate method (right-click
opens the context menu), but NetBird wired no Activate action, so on KDE a
left-click appeared completely dead while only right-click surfaced the menu.

Bind the Linux tray OnClick handler to ShowWindow(). OpenMenu() is not an
option on Linux: Wails v3 leaves linuxSystemTray.openMenu unimplemented (it
only logs), so left-click→OpenMenu would still do nothing on KDE. ShowWindow()
is the same call Windows already runs from its double-click handler, and it
does not reproduce the macOS OpenMenu freeze (c77e5cef8) — that came from
NSStatusItem's blocking embedded menu loop, whereas Show/Focus return
immediately.

Split the Linux click handler into its own tray_click_linux.go and narrow the
macOS no-op bindTrayClick build tag accordingly. The context menu stays on
right-click on every host. On hosts that already open the menu on left-click
natively (GNOME Shell + AppIndicator) left-click now opens the window instead;
the menu remains on right-click.
2026-06-01 22:04:49 +02:00
Zoltán Papp
072d789463 client/ui: retry XEmbed-tray probe before claiming SNI watcher
The XEmbed tray (panel) can come up after the autostarted UI on minimal
WMs, so the single startup probe added in #6320 could miss a tray that
appears a second or two later, leaving the icon silently absent. Re-probe
for a ~10s grace period in a goroutine, claiming the watcher as soon as a
tray shows up; back off cleanly if none ever appears (headless/Wayland).
2026-06-01 21:55:04 +02:00
Pascal Fischer
8d05fe07bf Fix watcher registration on wayland (#6320)
* Fix hover label on linux

* Fix watcher registration on wayland
2026-06-01 21:48:39 +02:00
Zoltán Papp
61da51ed2e client/peer: don't fan out unchanged management/signal state
MarkManagement{Connected,Disconnected} and MarkSignal{Connected,
Disconnected} fired notifyStateChange unconditionally. The connect
goroutine re-marks the same state on every health-check cycle, so a
steady "connected -> connected" re-mark pushed a full SubscribeStatus
snapshot to every consumer each time — flooding the desktop UI (and its
tray) with identical Connected snapshots.

Guard each with an early return when neither the state nor the error
actually changed, so only real transitions wake SubscribeStatus
subscribers. The notifier already deduplicates, so collapsing both calls
under one guard is safe.
2026-06-01 21:11:32 +02:00
Zoltán Papp
60c86c63aa client/server: throttle and single-flight health probes
Status(GetFullPeerStatus=true) RPCs trigger a full health probe
(network round-trips to management, signal and the relays). The
desktop UI issues these frequently and concurrently, and a burst of
parallel Get() calls each fired its own probe — the lastProbe guard
was unprotected against concurrent access and only advanced when every
component was healthy, so a sustained unhealthy state (e.g. relay down)
disabled the throttle entirely and let every call re-probe.

Extract the throttle/single-flight policy into probeThrottle:
  - single-flight: only one probe runs at a time; concurrent callers
    that piled up while it ran share its result instead of each
    launching another, even when that probe failed.
  - throttle: lastOK only advances on a fully successful probe, so
    while anything is unhealthy callers keep probing frequently and
    notice recovery quickly (preserved from the original design).

RunHealthProbes now takes a context so a caller that gives up (e.g. a
Status RPC whose client disconnected) cancels the in-flight STUN/TURN
probe instead of letting it run to its per-component timeout. The
engine's own lifetime ctx still applies independently.
2026-06-01 21:07:12 +02:00
Zoltán Papp
4cee07bef5 client/ui: use monochrome tray icons on Linux
Linux now shows monochrome (black/white silhouette) tray icons instead
of the colored orange PNGs, matching the macOS template look. Since
Wails' Linux SNI backend ignores SetDarkModeIcon (its setDarkModeIcon
just calls setIcon, last-write-wins) and the SNI spec carries no panel
light/dark hint, the panel color scheme is detected in-process and the
black-vs-white silhouette is chosen in iconForState, pushed via a single
SetIcon.

Detection order (tray_theme_linux.go): freedesktop Settings portal
(org.freedesktop.appearance/color-scheme) -> GTK_THEME env (:dark
suffix) -> default dark. A SettingChanged subscription repaints live on
theme flips. macOS (template) and Windows (colored) paths are unchanged.

Icons are 48x48 mono PNGs (3% margin) generated from the macOS
silhouettes.
2026-06-01 20:23:46 +02:00
Zoltán Papp
5bebecc427 ui: disable WebKit sandbox when unprivileged userns are blocked
WebKitGTK crashes at startup when its bubblewrap sandbox can't create an
unprivileged user namespace (bwrap: setting up uid map: Permission denied
-> Failed to fully launch dbus-proxy -> panic in webkit_web_view_load_uri).
This happens in containers/VMs and on Ubuntu 24.04+ where AppArmor
restricts unprivileged user namespaces. Detect that the kernel blocks
userns via procfs and set WEBKIT_DISABLE_SANDBOX_THIS_IS_DANGEROUS so the
UI stays usable; honor an explicit user override either way.
2026-06-01 20:06:44 +02:00
Zoltán Papp
3dbd96b172 Add Version service exposing GUI version to frontend 2026-06-01 19:23:41 +02:00
Eduard Gert
6fe35cae83 Merge remote-tracking branch 'origin/ui-refactor' into ui-refactor 2026-06-01 17:50:38 +02:00
Eduard Gert
88bd1f91a8 update session expire dialog to account for hours, days etc. 2026-06-01 17:50:21 +02:00
Eduard Gert
acfd680560 remove icons from profile dropdown 2026-06-01 17:37:05 +02:00
Pascal Fischer
0b8aae4566 Fix hover label on linux (#6318) 2026-06-01 17:36:39 +02:00
Eduard Gert
daf9a74d8f update dialog error for connection switch 2026-06-01 17:18:31 +02:00
Eduard Gert
8af90e40d5 Merge remote-tracking branch 'origin/ui-refactor' into ui-refactor
# Conflicts:
#	client/ui/frontend/src/contexts/SettingsContext.tsx
2026-06-01 17:15:38 +02:00
Eduard Gert
3f989f69cb update about gui version, add mock data, truncate some strings 2026-06-01 17:14:30 +02:00
Zoltan Papp
53d43980ad fix(ui): keep main window closable after a native dialog on Windows (#6319)
A native Windows MessageBox attached to a parent window disables that
window (WS_DISABLED) for its lifetime and re-enables it on dismissal.
When the parent is the main window — whose WindowClosing hook hides
instead of closes — the enable/hide sequence races and leaves the window
unable to process its close (X) button afterwards, so e.g. a rejected
login error dialog left the main window stuck open.

Route all native dialogs through src/lib/dialogs.ts, which forces
Detached: true on Windows (NULL owner, no window ever disabled) and is a
no-op on macOS/Linux (keeps the attached sheet-style presentation).
2026-06-01 17:03:53 +02:00
Zoltán Papp
49df24b18c escape ampersand in tray menu labels on Windows
Win32 swallows a lone & in an MFT_STRING menu item as the mnemonic
prefix, so "Help & Support" rendered as "Help  Support". Add a
build-tagged menuLabel() helper that doubles & to && on Windows and is
the identity on macOS/Linux (which render & literally), and apply it to
the About submenu label.
2026-06-01 16:44:18 +02:00
Zoltán Papp
c5611dd766 open main window on tray double-click on Windows
Wire the Windows systray's double-click to ShowWindow(), matching the
Windows-native convention for tray apps. The Wails v3 systray dispatches
WM_LBUTTONDBLCLK to the doubleClickHandler, so OnDoubleClick fires; left-
and right-click continue to open the menu. macOS/Linux are unchanged.
2026-06-01 16:41:44 +02:00
Eduard Gert
a4ad93008b add skeleton to launch netbird ui at login and own context 2026-06-01 14:57:31 +02:00
Eduard Gert
101e04f9fb update wails alpha from .95 to .97 2026-06-01 11:19:02 +02:00
Eduard Gert
710d5c6182 apply different window width on windows 2026-06-01 11:07:33 +02:00
Eduard Gert
7538a9a133 update window width and dialogs top padding on macos 2026-06-01 09:51:59 +02:00
Zoltan Papp
5f7657b95e Merge branch 'main' into ui-refactor 2026-05-31 12:33:50 +02:00
Zoltan Papp
27873866c2 Revert needs login icon 2026-05-31 04:34:17 +02:00
Zoltan Papp
18348e1491 client+ui: remove SSO handoff flicker and clean up abandoned login via context
Two follow-ups to the "hold NeedsLogin during the SSO browser wait" change.
Both target the visible state churn the tray showed during the auto-login
handoff (Connect / profile-switch lands on NeedsLogin -> the UI's startLogin
kicks off the SSO flow) and the broken recovery after the user dismisses the
browser-login popup with the window's X.

Background
----------
When a connect attempt lands on NeedsLogin, the UI's startLogin() drives the
SSO flow: Connection.Login() -> (NeedsSSOLogin) open the browser-login popup
-> Connection.WaitSSOLogin() blocks until the browser leg completes. The tray
and the React status page both paint the raw daemon status, so any transient
state the daemon publishes during this handoff is visible as a flicker.

Previously the handoff churned the daemon status through
  NeedsLogin -> Idle -> Connecting -> NeedsLogin
which read as a flicker on the tray icon and the status dot. Two distinct
sources produced the two intermediate states:

  * Idle       came from the UI's defensive cli.Down() at the top of
                Connection.Login (services/connection.go): it tore the engine
                down before every login to dislodge a possibly-parked
                WaitSSOLogin, emitting a StatusIdle on the way.
  * Connecting  came from server.go Login() unconditionally setting
                StatusConnecting before deciding whether the request is an
                SSO flow (which immediately returns NeedsLogin) or a
                setup-key flow (which actually dials Management).

Changes
-------
1. server.go Login(): only set StatusConnecting on the setup-key path, where
   we are about to dial Management with the key and the Connecting paint is
   meaningful. The SSO path returns NeedsLogin and parks on the browser leg,
   so it no longer flashes Connecting first. Removes the Connecting blip.

2. services/connection.go Login(): drop the pre-Login cli.Down(). The daemon
   already dislodges a pending WaitSSOLogin at Login entry (actCancel), and an
   abandoned browser leg is now torn down by cancelling the WaitSSOLogin RPC
   (see 3/4). Removing the Down removes the Idle blip on every login.

3. MainConnectionStatusSwitch.tsx startLogin(): on cancel (the browser-login
   popup's Cancel button or its window X, both routed through
   EventBrowserLoginCancel), cancel the in-flight WaitSSOLogin gRPC call via
   waitPromise.cancel() instead of issuing a heavy Connection.Down(). The
   daemon ties the wait to this call's context, so cancelling the call ends
   the wait cleanly with no engine teardown and no Idle paint.

4. server.go WaitSSOLogin(): when the wait unblocks with context.Canceled and
   the cancellation came from our caller (callerCtx.Err() != nil — the client
   cancelled the RPC or went away), clear the cached oauthAuthFlow so a fresh
   Login starts a new device code instead of reusing the abandoned one. The
   entry NeedsLogin stays in place, so a reattaching client still shows the
   login affordance. An internal abort (actCancel fired by a newer
   Login/WaitSSOLogin while our callerCtx is still live) is left untouched so
   the new owner's flow is not clobbered.

Effect
------
The auto-login handoff now goes Connected -> Connecting -> NeedsLogin and
holds, with no Idle/Connecting flicker in between. Dismissing the browser-login
popup with X now recovers the same way as the Cancel button: the WaitSSOLogin
RPC is cancelled, the stale OAuth flow is cleared, and the next connect opens a
fresh browser-login window instead of getting stuck.
2026-05-31 04:26:15 +02:00
Zoltan Papp
9569ac2081 internal: add temporary debug log for state.Set value and caller 2026-05-31 03:45:35 +02:00
Zoltan Papp
0b484133b2 client: hold NeedsLogin during SSO browser wait and tie it to the caller
WaitSSOLogin set StatusConnecting on entry and ran the browser wait on
rootCtx. If the client that drove the login went away mid-wait (UI restart,
CLI Ctrl+C), the wait orphaned on rootCtx until the OAuth device-code window
expired, and the daemon stayed stuck reporting Connecting — a reattaching
client saw a spinner that never resolved instead of a login prompt.

Hold StatusNeedsLogin for the whole browser wait (also in the Login
cached-flow path) so any client attaching mid-wait reads 'login required',
and bridge the wait to callerCtx so a departing client cancels it. On that
cancel the defer leaves NeedsLogin in place, so the next client shows the
login affordance instead of a stale Connecting.
2026-05-31 03:45:13 +02:00
Zoltan Papp
5df570feb8 ui: localize tray status labels for connect/login states
StatusLabel only mapped Idle and DaemonUnavailable, so Connected,
Connecting, NeedsLogin, LoginFailed and SessionExpired leaked the raw
daemon enum into the tray menu — untranslated in de/hu. Map all five to
tray.status.* keys (added in en/de/hu); keep the raw-enum default as a
fallback for any future status.
2026-05-31 03:15:34 +02:00
Zoltan Papp
ed4d823755 ui: auto-trigger browser login when profile switch lands on NeedsLogin
Add a second, longer-lived switchLoginWatch flag alongside switchInProgress
in DaemonFeed. Suppression still clears on the first Connecting push from the
new Up, but the login watcher survives past it to catch the eventual
NeedsLogin / LoginFailed / SessionExpired terminal and emit EventTriggerLogin,
so the React orchestrator opens the browser-login flow without a second
Connect click. shouldSuppress becomes consumeForSwitch, returning both the
suppress and triggerLogin signals. CancelProfileSwitch disarms the watch so
an aborted switch does not pop a login window.
2026-05-31 03:15:25 +02:00
Zoltan Papp
cedfa2ebf7 peer: add temporary debug log for notifyStateChange caller 2026-05-31 02:59:14 +02:00
Zoltan Papp
8b03c96851 ui: add launch-at-login (autostart) toggle for the UI
Add an Autostart Wails service wrapping app.Autostart and a toggle in
the General settings tab. The OS login-item registration is the single
source of truth (nothing mirrored to the preferences file). Affects the
graphical UI only, not the daemon. The toggle hides itself on platforms
where autostart is unsupported.
2026-05-31 02:01:02 +02:00
Zoltan Papp
b830a45333 ui: request macOS notification authorization on startup
Without an explicit authorization request the macOS notification center
keeps the app at .notDetermined and silently drops every toast. Request
it from the ApplicationStarted hook (after the notifier's Startup has
initialised the delegate), off the main goroutine since the call blocks
until the user responds. Linux/Windows notifier stubs report authorized,
so this is a no-op there.
2026-05-31 01:34:56 +02:00
Eduard Gert
b0d8ac6489 fix auto size detection 2026-05-29 17:21:45 +02:00
Eduard Gert
558769e671 localize window titles, fix size for windows (and other platforms?) 2026-05-29 16:58:08 +02:00
Zoltán Papp
fb6138a3ba tray: open menu on left-click on Windows
Wails v3 does not auto-show the tray menu on left-click on Windows — its
default left-click handler only logs and does nothing visible, so only
right-click opened the menu. macOS (NSStatusItem) and Linux
(StatusNotifierItem host) give us click→menu natively.

Add a build-tag-split bindTrayClick: the Windows variant wires
OnClick→OpenMenu (the same menu.ShowAt path right-click uses), while the
macOS/Linux variant stays a no-op — binding OnClick→OpenMenu on macOS
freezes the tray via NSStatusItem's blocking mouseDown on the main GCD
queue (the reason commit c77e5cef8 reverted the earlier wiring).
2026-05-29 16:02:41 +02:00
Zoltán Papp
b111c38b7c Merge remote-tracking branch 'origin/ui-refactor' into ui-refactor 2026-05-29 15:54:57 +02:00
Zoltán Papp
f54121ebfa Merge branch 'main' into ui-refactor
# Conflicts:
#	.github/workflows/golang-test-darwin.yml
#	.github/workflows/release.yml
2026-05-29 15:51:43 +02:00
Eduard Gert
122d172f33 Merge remote-tracking branch 'origin/ui-refactor' into ui-refactor 2026-05-29 15:40:41 +02:00
Eduard Gert
0b19a99693 update ui for win11 2026-05-29 15:40:34 +02:00
Zoltán Papp
0309f992ad ssh/server: keep testing dep out of the wasm build
test.go is a non-_test.go file so its exported StartTestServer helper is
visible to the ssh/proxy and ssh/client external test packages. That drags
the testing/flag/regexp chain into every build that links ssh/server,
including the wasm client (via the engine). Gate the file with //go:build
!js: native test packages still see the helper, wasm drops the dependency.
2026-05-29 15:37:04 +02:00
Zoltán Papp
1ad2d90d3b client: keep sessionwatch out of the wasm build
The wasm client never runs the engine's session-warning flow, so linking
the full sessionwatch package (timers, event composition) only bloats the
binary. Put the watcher behind a sessionDeadlineWatcher interface and split
the constructor by build tag: !js wires the real sessionwatch.Watcher, js
gets a no-op stub that still mirrors the deadline into the status recorder
(so the Status snapshot stays correct) but drops the timers. Removes the
sessionwatch package from the wasm dependency graph.
2026-05-29 15:26:21 +02:00
Zoltán Papp
93a1547871 ci: skip non-English UI locales in codespell
de/hu translations contain real foreign words (Sie, oder, ist) that
codespell flags as misspellings. Only en/common.json is the spell-check
source of truth; add each new locale dir to the skip list as languages land.
2026-05-29 15:07:53 +02:00
Zoltán Papp
04ab9b5bad ui: split main() into focused setup helpers
Extract parseFlagsAndInitLog, newApplication, buildI18n, registerServices,
and newMainWindow so main() stays under the 100-line limit. Wiring order and
shared service instances are unchanged.
2026-05-29 15:02:44 +02:00
Zoltán Papp
61431801ea ui: extract subscribeAndStreamEvents to cut toastStreamLoop complexity
Move the event backoff op body into subscribeAndStreamEvents and the
per-event fan-out into dispatchSystemEvent, bringing toastStreamLoop under
the 20 cognitive-complexity limit. No behavior change.
2026-05-29 14:58:17 +02:00
Zoltán Papp
02e3cb9987 ui: document why startStatusNotifierWatcher is empty on non-Linux
Explain that macOS/Windows have a native tray and the SNI+XEmbed bridge is
Linux-WM-only, so the body is intentionally empty to let main.go call it
unconditionally across all build targets.
2026-05-29 14:56:49 +02:00
Zoltán Papp
7a78b9df8a ui: extract subscribeAndStreamStatus to cut statusStreamLoop complexity
Move the status backoff op closure body into a method so the nested
closure no longer carries the stream loop and its conditionals, bringing
statusStreamLoop under the 20 cognitive-complexity limit. No behavior change.
2026-05-29 14:49:39 +02:00
Zoltán Papp
1416a2e160 ui: reduce cognitive complexity in tray/feed/xembed status handlers
Extract helpers to bring three methods under the 20 cognitive-complexity
limit without changing behavior:

- DaemonFeed.statusStreamLoop: split out handleStatusRecvErr and emitStatus
- Tray.applyStatus: split out consumePendingConnectLogin and
  refreshMenuItemsForStatus
- xembedHost.flattenMenu: split out menuItemFromLayout plus propString /
  propBool / propInt32 dbusmenu property accessors
2026-05-29 14:45:58 +02:00
Zoltán Papp
88db1724bf Merge branch 'main' into ui-refactor 2026-05-29 14:39:40 +02:00
Eduard Gert
d0d7252c24 update settings bottom bar height 2026-05-29 14:25:36 +02:00
Eduard Gert
9dc9e7184e add first run lang detection 2026-05-29 14:24:20 +02:00
Eduard Gert
1985caf993 add os detection 2026-05-29 14:04:45 +02:00
Eduard Gert
16570b3223 prevent content flash in settings 2026-05-29 13:43:48 +02:00
Eduard Gert
967235e964 update button size and weight 2026-05-29 13:10:38 +02:00
Eduard Gert
7d876571da update CLAUDE.md 2026-05-29 13:08:06 +02:00
Eduard Gert
e6a624dcee preload settings window, prevent opening hidden windows on macos 2026-05-29 13:07:34 +02:00
Zoltán Papp
bee92f5fcd ui/frontend: update StatusContext for Peers → DaemonFeed rename
Missed in the previous commit. The StatusContext is the only frontend
consumer of the renamed service (the modules/main/.../peers/Peers.tsx
React component is a different identifier — unchanged).
2026-05-28 21:43:44 +02:00
Zoltan Papp
f4914fdfcc build: replace Wails3 scaffolding placeholders with NetBird identity
The build/config.yml that wails3 init scaffolded shipped with 'My Company',
'My Product', 'com.mycompany.myproduct' and '(c) 2025, My Company' template
defaults. The per-platform assets generated from it (Info.plist,
Info.dev.plist, info.json, nsis/wails_tools.nsh) carried the same strings,
which were visible in macOS Finder Get Info, Windows .exe Properties and
the NSIS installer.

Updated to the NetBird identity used by the legacy Fyne UI on main:

- companyName / copyright   -> 'NetBird GmbH' (matches main release.yml's
                              COPYRIGHT env passed to goversioninfo)
- productName               -> 'NetBird'
- productIdentifier         -> 'io.netbird.client' (matches CFBundleIdentifier)
- description               -> 'NetBird desktop client'
- darwin NSHumanReadableCopyright   -> 'NetBird GmbH'
- windows LegalCopyright            -> 'NetBird GmbH'
- nsis INFO_COPYRIGHT               -> 'NetBird GmbH'

Version fields (0.0.1) are left in place: release builds get the real
version via goversioninfo (Windows) and sign-pipelines (macOS .app),
so the placeholder is only visible in local task package / task run
output and doesn't reach release artifacts.
2026-05-28 21:32:16 +02:00
Zoltán Papp
2cdc6ef1c6 ui: split tray.go into feature files, rename Peers service to DaemonFeed
The 1542-line tray.go grew into a 14-feature kitchen sink. Split it
into feature-coherent same-package siblings, give the daemon-stream
service a name that matches what it actually does, and trim the
cargo-cult context.WithCancel pattern from click handlers.

File layout (tray.go: 1542 → ~470 lines):
  - tray_status.go    onStatusEvent / applyStatus / status indicator
  - tray_icon.go      applyIcon / iconForState (tray icon painting)
  - tray_events.go    onSystemEvent + eventTitle / titleCase, plus a
                      shouldSkipSystemEvent helper that names the
                      three "daemon notification we don't surface"
                      filters
  - tray_session.go   session-expiry row + warning notification flow +
                      handleSessionExpired (moved from tray.go)
  - tray_profiles.go  loadConfig / loadProfiles / switchProfile
  - tray_exitnodes.go exit-node submenu (rebuild / refresh / toggle)

Mutex split: the kitchen-sink t.mu becomes four domain-scoped mutexes
so a long-running gRPC call in one domain can't block status-push
readers in another:
  - statusMu        connected / lastStatus / lastDaemonVersion /
                    lastNetworksRevision / pendingConnectLogin
  - sessionMu       sessionExpiresAt (read by the 30s ticker,
                    written by applySessionExpiry on every status push)
  - profileMu       activeProfile / activeUsername /
                    notificationsEnabled / switchCancel
  - exitNodesMu     row cache (read in reapplyMenuState's Repaint copy)
  - exitNodesRebuildMu  serialises ListNetworks + submenu rebuild +
                        SetMenu (already separate, kept)

Service rename: the "Peers" service handled the daemon's full
SubscribeStatus snapshot (peers, daemon version, management/signal
link state, networks revision, SSO deadline) plus the SubscribeEvents
notification stream and the profile-switch suppression filter. Peers
was a misleading name for a daemon-stream fan-out service. Rename to
DaemonFeed in services/, profileswitcher's stored reference, the
TrayServices struct, main.go wiring, and every doc comment that
referenced it. peers.go → daemon_feed.go. The Status.Peers field
itself (the peer list in the snapshot) is unchanged.

Event constant renames (wire strings unchanged so the frontend keeps
working without regenerating bindings beyond the rename):
  - EventStatus → EventStatusSnapshot
    Payload is a full Status struct (daemon-wide snapshot), not just
    a state-change ping — name the value-shape.
  - EventSystem → EventDaemonNotification
    Payload is a daemon SystemEvent meant to drive an OS toast or a
    Recent Events row. "System" was too generic; "Notification"
    matches what consumers do with it.

Concurrency fixes:
  - WaitExtendAuthSession now preempts a previous in-flight wait
    via the existing SetWaitCancel/CancelWait infrastructure on
    PendingFlow, the same pattern WaitSSOLogin uses. The previous
    waiter exits with codes.Canceled; the authsession service
    translates that to ExtendResult{Preempted: true} so the tray
    and the about-to-expire dialog stay silent on the losing flow
    instead of showing a false-failure toast. Without this, both
    a tray "Extend now" click and a dialog "Stay connected" click
    on the same deadline started two parallel IdP polls, and
    whichever lost the device-code check painted a bogus error.
  - mgmClient.ExtendAuthSession drops the dead backoff retry loop.
    The loop only retried on codes.Canceled, but the inner mgmCtx
    was derived from context.Background() and never cancelled, so
    every real error went straight to backoff.Permanent on the
    first attempt. Replace with a single
    context.WithTimeout(c.ctx, ConnectTimeout) call; daemon
    shutdown now interrupts the RPC and behaviour on real errors
    is unchanged.

Click-handler hygiene: six call sites used the cargo-cult
context.WithCancel(context.Background()) + defer cancel() pattern
without ever calling cancel() externally. Replace with
context.Background() directly (loadConfig, loadProfiles,
runExtendSession, dismissSessionWarning, handleConnect's Up,
handleDisconnect's Down). The one site that genuinely needs the
cancel — switchProfile, which stores it in t.switchCancel so
handleDisconnect can preempt the switch — keeps WithCancel.

Helper extraction: shouldSkipSystemEvent groups the three
"daemon notification we drop on the floor" checks
(new_version_available metadata, progress_window metadata, the
::/0 partner of an exit-node default-route event) behind a single
named predicate. Each had a comment explaining why; collecting
them moves the rationale into the helper docstring and shrinks
onSystemEvent to a router.
2026-05-28 21:26:57 +02:00
Zoltán Papp
3279b705fe session-extend: drop dead retry loop in mgmClient.ExtendAuthSession
The backoff loop only retried on codes.Canceled, but mgmCtx was derived
from context.Background() and never cancelled by anything — so every
real error path (Unavailable, DeadlineExceeded, etc.) went through
backoff.Permanent on the first attempt. The loop was a no-op wrapper
that just held the call open for the daemon's lifetime regardless of
shutdown.

Replace with a single context.WithTimeout(c.ctx, ConnectTimeout) call.
Daemon shutdown now interrupts the RPC; behaviour on real errors is
unchanged.
2026-05-28 19:28:15 +02:00
Zoltán Papp
e94a4cbce5 session-extend: preempt previous WaitExtendAuthSession on new wait
When the tray "Extend now" notification action and the about-to-expire
dialog both start a flow for the same deadline, the daemon was running
two independent IdP polls and the older one surfaced an InvalidArgument
toast as soon as the second RequestExtend overwrote the pending flow.

Follow the WaitSSOLogin pattern: at the top of WaitExtendAuthSession
cancel the previous wait (the SetWaitCancel/CancelWait pair on
PendingFlow already existed but was unused), then register the new
wait's cancel. Preempted callers exit with codes.Canceled; the
authsession service translates that into ExtendResult{Preempted: true}
so the tray and the React dialog can stay silent on the losing flow
instead of showing a false-failure toast / error dialog.
2026-05-28 19:17:46 +02:00
Eduard Gert
c1db8ab0ab add manage profiles to tray 2026-05-28 18:04:38 +02:00
Eduard Gert
2bf945e745 remove unused packages 2026-05-28 17:18:49 +02:00
Eduard Gert
4556d52a60 fix view mode toggle 2026-05-28 16:36:15 +02:00
Eduard Gert
51b243bdfa remove unused stuff, refactor frontend folder structure 2026-05-28 16:26:13 +02:00
Eduard Gert
e09bc8894d Merge remote-tracking branch 'origin/ui-refactor' into ui-refactor 2026-05-28 15:45:51 +02:00
Zoltan Papp
55c1f44fb0 build: drop -buildvcs=false so go embeds vcs.revision into the ui binary
All Wails3 Taskfiles passed -buildvcs=false to go build, which disables
the automatic VCS info embedding Go 1.18+ does by default. As a result
runtime/debug.ReadBuildInfo() returned an empty vcs.revision in our
netbird-ui binary, so the upcoming version.NetbirdCommit() helper from
PR #6263 could not display the git sha for dev builds.

Removed from build:native in all three platform Taskfiles plus the
Windows build:console and the Dockerfile.cross cross-compile script.
go version -m bin/netbird-ui now reports vcs.revision and vcs.modified.
2026-05-28 15:22:05 +02:00
Eduard Gert
ac8d417c12 update exit node tab 2026-05-28 14:43:28 +02:00
Eduard Gert
dccc0ebe4b update resources tab 2026-05-28 14:28:51 +02:00
Zoltán Papp
35498c572a ci: bump node to v22 in release workflow
pnpm 11 requires Node.js >= 22.13 (uses node:sqlite, added in 22.5),
but the release workflow still pinned Node 20. After bumping pnpm to
v11 in the previous commit, the frontend build hook now fails with
ERR_UNKNOWN_BUILTIN_MODULE 'node:sqlite' until Node also moves to 22.
2026-05-28 13:57:30 +02:00
Zoltán Papp
cda621bb27 ci: bump pnpm to v11 in release workflow
The frontend uses pnpm 11 (packageManager field, v11 lockfile, and the
allowBuilds key in pnpm-workspace.yaml is a pnpm 10+ feature), but the
release_ui job's pnpm/action-setup was pinned to v9. v9 rejects the
workspace file with 'packages field missing or empty' before the
frontend build hook can run.
2026-05-28 13:51:43 +02:00
Zoltán Papp
d57b30f8d5 Merge branch 'main' into ui-refactor 2026-05-28 13:43:19 +02:00
Zoltan Papp
d82b950718 frontend: approve esbuild postinstall via pnpm-workspace.yaml
pnpm 11 blocks dependency build scripts by default and exits non-zero
when any are skipped, which made task build fail at install:frontend:deps.
esbuild's postinstall is required to fetch the platform-specific binary.
2026-05-28 13:40:43 +02:00
braginini
3bd058d425 Use old version of the Dock icon 2026-05-28 12:45:36 +02:00
Zoltan Papp
0082f51830 i18n: pluralize exit node nav title 2026-05-28 11:46:19 +02:00
Zoltan Papp
e4420b1f96 tray: separator between troubleshoot and version info in about submenu 2026-05-28 11:44:13 +02:00
Zoltan Papp
a5635f8825 tray: use yellow connecting dot for needs-login state 2026-05-28 11:41:22 +02:00
Zoltan Papp
966fbec119 routemanager: enforce a single selected exit node
Exit nodes are mutually exclusive, but the RouteSelector stores routes with
default-on semantics, so every available exit node reported as selected at once.

Reconcile exit-node selection on each network map (and on runtime selection):
keep at most one selected — the user's persisted pick, else whatever management
marks for auto-apply (SkipAutoApply=false), else none. Never auto-activate an
exit node the map doesn't request; it stays off until the user picks it.

The server deselects sibling exit nodes when the user activates one (leaving
non-exit routes untouched), and the tray/React exit-node toggle now appends so
activating an exit node no longer wipes network-route selections.
2026-05-27 20:48:16 +02:00
Zoltan Papp
f693d268b4 tray: selectable exit nodes + push-based network list refresh
Make the tray Exit Node submenu selectable (mutually exclusive, sourced from
ListNetworks by NetID) instead of read-only.

Add networksRevision to the status snapshot, bumped by the route manager on
network-map and selection changes, so the tray and the React NetworksContext
re-fetch ListNetworks via the push stream instead of polling. The peer-status
route list only carries chosen routes, so a candidate exit node appearing or
disappearing would otherwise never reach the UI.
2026-05-27 20:48:16 +02:00
Eduard Gert
09f4109b01 update peers ui 2026-05-27 18:01:06 +02:00
Eduard Gert
ad7d7fa881 change font weight 2026-05-27 17:03:31 +02:00
Eduard Gert
b84c7618e7 fix viewmode height, update other ui stuff 2026-05-27 16:40:57 +02:00
Eduard Gert
ec5da43d73 persist viewMode across restarts 2026-05-27 15:52:15 +02:00
Eduard Gert
a8ad73d2d9 add shortcuts in tray for quit and settings item 2026-05-27 15:37:05 +02:00
Eduard Gert
a241112a1d disable resize 2026-05-27 15:35:51 +02:00
Eduard Gert
e62dff0f66 add github, docs links etc. to settings about page 2026-05-27 15:35:40 +02:00
Eduard Gert
5cecca2c23 store viewmode in ui preferences 2026-05-27 15:21:51 +02:00
Eduard Gert
0e83d2ad94 add peer details 2026-05-27 14:52:19 +02:00
Zoltan Papp
004a305e46 tray: add 30s ticker to keep session-expiry countdown fresh 2026-05-27 00:33:09 +02:00
Zoltan Papp
c77e5cef85 tray: revert on-open click handler — OpenMenu freezes tray and React
Binding OnClick/OnRightClick to call OpenMenu() on macOS routes the menu
open through showMenu(), which runs the blocking [button mouseDown:] inside
a dispatched block on the serial main GCD queue. While the menu is open that
block never returns, starving every other main-queue task — both tray item
updates and the webview event delivery that drives React freeze until the
menu closes.

Revert to the pre-d9f0189 state: no click handlers bound, native NSStatusItem
auto-show for left-click, Wails default rightClickHandler for right-click.
refreshSessionExpiresLabel() is kept for the follow-up fix.
2026-05-27 00:24:46 +02:00
Zoltan Papp
13179081d2 Merge branch 'main' into ui-refactor 2026-05-26 23:41:18 +02:00
Zoltan Papp
2d3c8fc555 tray: drop dead iconMenuNetbird and openRoute after menu rework
The menu reorganisation removed the About brand-mark bitmap and rerouted
every openRoute caller to WindowManager auxiliary windows, leaving both
the iconMenuNetbird embed (all three platforms) and the openRoute helper
unreferenced. Remove them so the unused linter passes.
2026-05-26 23:25:45 +02:00
Zoltan Papp
61aa3a53ed tray: re-enable Exit Node menu item when candidates arrive post-connect
The parent Exit Node item's enablement was only refreshed on icon/status
transitions. The daemon ships peer routes in a later snapshot than the
Connected status text, so after a profile switch the candidate list flips
empty to non-empty while the status string is unchanged — leaving the item
greyed and the freshly painted rows unreachable. Re-evaluate enablement in
the exitNodesChanged branch too.
2026-05-26 23:15:03 +02:00
Zoltan Papp
80d6df6260 tray: rework menu layout, exit-node submenu, session countdown wording
- Reorder the menu: status, Connect/Disconnect, profile block, Open
  NetBird, Exit Node, then Settings… / Help & Support / Quit NetBird.
- Rename About → Help & Support, Quit → Quit NetBird, Settings → Settings…
  (ellipsis flags the window-opening action per the macOS HIG); drop the
  brand icon from Open NetBird; enable Documentation (opens docs.netbird.io)
  and add a Troubleshoot entry that deep-links the Settings window.
- Exit Node is now a submenu listing only peers that advertise a default
  route (0.0.0.0/0 or ::/0), sorted case-insensitively; the row stays
  visible but greyed when the tunnel is down or no candidate exists.
- Session row reads "Session expires in <n minutes/hours/days>" and
  recomputes on menu open so the countdown tracks wall time between the
  daemon's status pushes.
2026-05-26 23:15:03 +02:00
Zoltan Papp
53bbc2d551 session: clear stale SSO deadline on teardown and after expiry
The session deadline lived in two sinks kept in sync by hand:
ApplySessionDeadline wrote both the (engine-scoped) sessionwatch.Watcher
and the (server-scoped) peer.Status recorder. The clear paths only
touched the watcher, so the recorder — which is what the Status RPC /
SubscribeStatus snapshot the UI reads from — kept reporting a deadline
that had gone stale, surfacing as a frozen "expires in …" countdown.

Two cases were leaking:
- Profile switch / Down: the watcher is recreated per engine but the
  recorder outlives it, so a switch to a profile whose server sends no
  deadline left the previous profile's value in place.
- In-place expiry: the watcher arms warning timers at T-WarningLead and
  T-FinalWarningLead but nothing at the deadline itself, so once the
  moment passed the recorder kept the now-past value indefinitely.

Make the watcher the single writer of the recorder deadline (Update /
clearLocked / Close all route through SetSessionExpiresAt) so teardown
clears it, and guard GetSessionExpiresAt to report a past deadline as
none so in-place expiry stops painting a stale countdown.
2026-05-26 23:15:03 +02:00
Zoltan Papp
d9f0189b57 tray: reorganise menu, refresh expiry countdown on open
Layout changes:
- Drop "Debug Bundle" row; reach the flow via the in-window Settings UI.
- Move the brand-mark icon from the About row to "Open NetBird".
- Collapse Settings / Exit Node / About into a single block, with the
  Settings → Exit Node order swap to put the configuration entry first.
- Relocate Connect / Disconnect to the bottom block, sharing its
  separator with Quit. Drops the connectSeparator field + lastMenuItem
  helper that only existed to suppress the daemon-unavailable double
  separator in the old position.

Countdown freshness: the daemon's Status snapshots arrive too coarse to
keep a minute-grained "Expires in …" row honest while the menu is
closed. Wails v3 alpha 95 does not expose a public NSMenu needsUpdate
hook, so the tray binds OnClick / OnRightClick and recomputes the label
from cached sessionExpiresAt just before the menu paints. macOS and
Windows right-click additionally call OpenMenu() to restore the native
auto-show that binding the handler suppresses; Linux's dbusmenu host
paints the menu itself.
2026-05-26 23:15:03 +02:00
Eduard Gert
91e0520f27 move locales to client/ui/i18n 2026-05-26 12:34:01 +02:00
Eduard Gert
67a1f3c4fe add peers, networks and exit node list (wip) 2026-05-26 12:09:08 +02:00
Zoltan Papp
b6d20edfeb tray: show NetBird brand mark next to About on macOS
NSMenuItem rejected the dedicated netbird-menu-24.png brand mark
(rendered muddy) and the full 256x256 brand PNG (stretched the row).
Ship an 18x18 sips-downscale of assets/netbird.png — same source the
legacy Fyne client used for its About row — to sit visually alongside
the cap-height of the surrounding text.
2026-05-26 11:40:30 +02:00
Zoltan Papp
18d0019332 tray: drop Networks menu item, make session-expiry row open extend flow
Networks row removed from the tray; Exit Node remains the only routed-
state entry. Clicking the "Expires in …" row now opens the
SessionAboutToExpire window seeded with the actual remaining seconds, so
users can extend the SSO session proactively instead of waiting for the
daemon's T-FinalWarningLead auto-prompt.
2026-05-26 11:40:30 +02:00
Eduard Gert
ecee7df5d8 Merge remote-tracking branch 'origin/ui-refactor' into ui-refactor 2026-05-26 11:27:17 +02:00
Eduard Gert
1d783c33d9 adjust left offset in ip and fqdn 2026-05-26 11:26:51 +02:00
Eduard Gert
b14feef1d7 add copy to clipboard 2026-05-26 11:23:08 +02:00
Zoltán Papp
0935a5675d tray: move session-expiry row under profile email, hide separator when daemon unavailable
- Relocate the session-expiry row from below the status item to below the
  profile email so active profile, email, and session deadline form one block.
- Rename the label to "Expires in {remaining}" (en/hu/de).
- Capture the Connect/Disconnect separator via lastMenuItem and hide it when
  both action rows are hidden (daemon unavailable), avoiding two adjacent
  separators with nothing between them.
2026-05-26 10:59:36 +02:00
Eduard Gert
4818599a93 sort peers 2026-05-26 09:50:18 +02:00
Eduard Gert
f8c107b087 update peers filter, fix duplicate url open in dev 2026-05-26 09:35:03 +02:00
Eduard Gert
d624c2db74 fix i18n label 2026-05-22 16:36:07 +02:00
Eduard Gert
513ecd456c remove mock peers 2026-05-22 16:28:24 +02:00
Eduard Gert
8f957ff41a fix scroll in settings 2026-05-22 16:09:55 +02:00
Eduard Gert
598fcbd817 remove unused lang icons, disable text selection 2026-05-22 15:59:27 +02:00
Eduard Gert
17a365926d update auto update wordings, add update available into ui 2026-05-22 13:04:04 +02:00
Eduard Gert
577ce6deb5 fix connect flow in tray 2026-05-22 10:16:13 +02:00
Eduard Gert
580cfa0dc5 add default and advanced resize 2026-05-22 09:53:08 +02:00
Zoltán Papp
8d4f35352f skip About-row brand mark on macOS
NSMenuItem.setImage stretches the row to the leading image's pixel
size regardless of the surrounding rows, so any non-empty bitmap on
the About entry made it visibly taller than the rest of the tray
menu — leaving 16, 18 or 22 px versions all looking wrong next to
the unadorned rows above and below.

Drop the macOS brand mark and gate the SetBitmap call on a non-empty
byte slice; iconMenuNetbird is now nil on macOS, so the About row
falls back to text only. Windows and Linux still ship the brand mark
through their per-platform embed files.
2026-05-21 17:01:08 +02:00
Zoltán Papp
85029898a5 per-platform tray menu icons and Windows-specific status row
The Windows menu renderer paints leading bitmaps into the Win32
check-mark slot (SetMenuItemBitmaps), which differs from how Cocoa
and GTK handle NSMenuItem.image / menu-row icons:

  - SM_CXMENUCHECK sizing: Windows expects ~16x16 at 100% DPI in the
    check-mark slot and visually overflows the row for anything bigger.
  - Disabled-state mask: Windows desaturates both the row text and the
    bitmap when MFS_DISABLED is set, so a disabled informational row
    renders the coloured status dot in greyscale.

Per the platform icon guidelines:

  Platform | Size           | Notes
  ---------|----------------|-----------------------------------------
  Windows  | 16x16          | check-mark slot, status row stays enabled
  macOS    | 22x22 (18-22)  | NSMenuItem leading image, HIG
  Linux    | 24x24 (22-48)  | GTK4 menu-row icon channel

Changes:

  * Split the menu-row icon embeds into icons_menu_{windows,darwin,linux}.go
    so each platform pulls its own size; the brand mark is rendered from
    assets/svg/netbird-menu.svg (new vector source) at 16/22/24 px with
    Inkscape, and the Windows status dots ship as 8x8 content centred on
    a 16x16 transparent canvas (the renderer upscales the bitmap, so the
    padding keeps the dot visually proportional to the row text).

  * Introduce statusRowEnabled() in tray_status_enabled_{windows,other}.go:
    true on Windows so the disabled-state mask does not strip the dot's
    colour; false on macOS/Linux where disabled menu rows fade the label
    without desaturating the leading bitmap, signalling that the row is
    informational.

  * Add an icon to the About submenu using the same brand mark.
2026-05-21 16:41:52 +02:00
Zoltán Papp
c3aeb5be15 force dark window theme on Windows 2026-05-21 14:59:00 +02:00
Eduard Gert
df61f22d96 add error msg to profile context and auto update 2026-05-21 09:49:32 +02:00
Eduard Gert
32df29bbd4 Merge remote-tracking branch 'origin/ui-refactor' into ui-refactor
# Conflicts:
#	client/ui/frontend/src/screens/Update.tsx
2026-05-21 09:34:45 +02:00
Zoltán Papp
0a458ead8b port xembed tray popup menu from gtk3 to gtk4 2026-05-20 19:53:24 +02:00
Zoltan Papp
aab8274b1a clear connect-action latch when external disconnect cancels Connecting
The main-window toggle stayed visually stuck on "Connecting" when the
user clicked Connect in the UI and then clicked Disconnect in the
tray (or the daemon was otherwise cancelled mid-Connecting).

Repro: open the main window, click the toggle to Connect, then while
the daemon is still in Connecting click Disconnect in the tray menu.
The tray and daemon agree the session is Idle, but the React toggle
keeps painting "Connecting" until the next manual interaction.

Root cause is in ConnectionStatusSwitch.tsx. The component holds an
`action` latch ("connect" | "logging-in" | "disconnect" | null) so the
toggle can show an optimistic transitional state while the daemon
catches up. The connState memo treats `action === "connect"` plus any
non-Connected daemon state as Connecting:

    if ((action === "connect" || action === "logging-in") &&
        daemonState !== "Connected") {
        return ConnectionState.Connecting;
    }

The effect that releases the latch only cleared it on `Connected` or
`DaemonUnavailable`. There was no branch for "the connect flow was
cancelled externally and the daemon is back at Idle", so the latch
remained set forever and the optimistic Connecting state never
collapsed.

Fix: add a `sawConnectingRef` that flips to true the first time the
daemon reports Connecting during an active "connect" action, and
resets when `action` returns to null. When `action === "connect"` and
the daemon flips from a state we'd observed as Connecting back to
Idle, clear the latch so connState falls through to Disconnected.

Other paths are untouched:
- Successful connect still clears on Connected.
- NeedsLogin still hands off to driveLogin.
- DaemonUnavailable still clears via the `unreachable` branch.
- The `"logging-in"` action is intentionally not handled here; Login's
  internal Down flaps the daemon through Idle and driveLogin's
  .finally remains the sole clearer for that latch.
- The `"disconnect"` action's Idle/Disconnected/unreachable clear is
  unchanged.
2026-05-20 19:44:02 +02:00
Zoltan Papp
d3b660afba classify daemon login errors and surface localised dialogs
The daemon returns gRPC errors whose message is a wrapped mgm + JWT
stack (e.g. "invalid jwt token, err: token could not be parsed: ...").
Showing that in a native dialog is unreadable. Connection now maps the
substrings it recognises to a ClientError{code, short, long} so the UI
can render a localised summary plus a Details: block carrying the raw
daemon text. formatErrorMessage on the TS side reads the structured
payload from Wails' Error.cause (or the JSON-stringified Error.message)
and falls back to plain Error.message for callers not yet migrated.

Also bumps Wails to v3.0.0-alpha.95.
2026-05-20 19:13:13 +02:00
Zoltán Papp
341848b1ae fix lint issues in session watcher tests and status humaniser 2026-05-20 18:46:56 +02:00
Eduard Gert
414e7815e4 update default view icon, remove capitalize from profile name 2026-05-20 16:45:06 +02:00
Zoltán Papp
ef6b4f7538 add SSO session extend flow
Adds an end-to-end SSO session-extension feature: the management server
publishes per-peer session deadlines on every Login/Sync, a new
ExtendAuthSession RPC refreshes the deadline using a fresh JWT without
tearing down the tunnel, and the daemon tracks the deadline locally so
the UI can fire a T-10min warning toast with an interactive "Extend now"
action.
2026-05-20 16:43:14 +02:00
Eduard Gert
a7b26e3c0d add updating dialog 2026-05-20 16:20:40 +02:00
Eduard Gert
42534b24c5 fix scrollarea inside settings 2026-05-20 13:43:18 +02:00
Eduard Gert
2aea1f7bb5 Merge remote-tracking branch 'origin/ui-refactor' into ui-refactor 2026-05-20 13:38:34 +02:00
Eduard Gert
620233a7ac update dropdown ui padding, remove unused stuff 2026-05-20 13:38:23 +02:00
Eduard Gert
1c15e9976b add profiles tab to settings 2026-05-20 13:17:13 +02:00
Zoltán Papp
f04e2bada8 [ci] Switch CI deps to GTK4 / WebKitGTK 6.0
Wails v3 alpha.94 switched its default Linux backend from GTK3 +
WebKit2GTK 4.1 to GTK4 + WebKitGTK 6.0 (the GTK3 path is now gated
behind a `gtk3` build tag). cgo files that the binary, the tests, and
the lint job all parse now request `pkg-config --cflags gtk4
webkitgtk-6.0 ...`, so the existing libgtk-3-dev + libwebkit2gtk-4.1
apt deps no longer satisfy them — lint, unit tests, and the linux
release build all fail with `Package 'gtk4' ... not found`.

Replace the apt deps across the four workflows that build/lint the
client tree (golangci-lint, golang-test-linux, release, and the wasm
lint job that also walks client/) with libgtk-4-dev + libwebkitgtk-6.0-dev
+ libsoup-3.0-dev. Both packages are available from jammy (22.04 LTS)
onwards, so existing ubuntu-22.04 runners stay valid.
2026-05-20 12:46:37 +02:00
Zoltán Papp
1d88faf66f [ci] Stage WebView2 bootstrapper in test_windows_installer
client/installer.nsis:317 calls `File "MicrosoftEdgeWebview2Setup.exe"`
and client/netbird.wxs references the same payload. In the release
pipeline that file is generated by `wails3 generate webview2bootstrapper`
inside netbirdio/sign-pipelines; the netbird repo's test_windows_installer
job never ran that step, so makensis aborted with:

  Error in macro nb.webview2runtime on macroline 21
  Error in script "...\client\installer.nsis" on line 325

Mirror the sign-pipelines recipe: set up Go, install wails3 (version
derived from go.mod so the bootstrapper always matches the linked
runtime), then stage the bootstrapper into client/ before the makensis
step runs.
2026-05-20 12:17:11 +02:00
Zoltán Papp
84093af1f0 Bump wails/v3 to v3.0.0-alpha.94
Picks up alpha.92..94 fixes; the binding generator and the
@wailsio/runtime npm package (pinned to "latest") stay compatible.
Brings tranzitive upgrades along (go-git, golang.org/x/exp,
golang.org/x/mod, golang.org/x/text, golang.org/x/tools, pjbgf/sha1cd).
2026-05-20 12:11:28 +02:00
Zoltán Papp
34a4744565 [ci] Wire wails3 bindings generation into darwin UI release
The release_ui_darwin job builds the macOS UI bundle from
.goreleaser_ui_darwin.yaml, but cccb0e92 only added the wails3 CLI
install + bindings-regen hook to the Linux side (release.yml release_ui
job and .goreleaser_ui.yaml). The darwin counterpart still ran pnpm
build against the gitignored, empty bindings/ directory and failed with
~40 TS2307 "Cannot find module '@bindings/...'" errors.

Mirror the Linux setup on darwin: install wails3 from the version
pinned in go.mod, and run `wails3 generate bindings -clean=true -ts`
as the first goreleaser before-hook so vite can resolve @bindings/* by
the time pnpm build starts.
2026-05-20 11:20:30 +02:00
Eduard Gert
b79b62bee4 add default and advanced view items into dropdown 2026-05-20 09:39:35 +02:00
Eduard Gert
bec4eb326a update new profile modal 2026-05-19 18:53:19 +02:00
Eduard Gert
8748f3810d update profile ui 2026-05-19 18:27:05 +02:00
Eduard Gert
1c5254cb31 update profile ui 2026-05-19 14:21:14 +02:00
Zoltán Papp
3f8cd29006 Merge remote-tracking branch 'origin/main' into ui-refactor 2026-05-18 23:31:13 +02:00
Eduard Gert
ca48de549e make dialogs draggable, disable selecting text 2026-05-18 16:34:38 +02:00
Eduard Gert
5b71a4f2ad update dialogs, hide main window on browser login, keep state as disconnected when needslogin 2026-05-18 16:31:59 +02:00
Eduard Gert
741ce8581d fix open settings in tray, prevent loading profiles when daemon is down 2026-05-18 13:07:34 +02:00
Zoltan Papp
6b44d65cac report daemon-down as DaemonUnavailable on initial Peers.Get and gate UI
- Peers.Get returns Status{Status: DaemonUnavailable} on Unavailable
  instead of an error so the React useStatus initial refresh picks up
  the same string the live event stream emits — the overlay no longer
  depends on receiving the synthetic event during boot.
- ProfileContext.refresh swallows Unavailable so the redundant
  "Load Profiles Failed" popup does not overlap the overlay.
- Tray Profiles submenu is disabled while the daemon is unavailable,
  matching the existing settings/debug/connect gating.
- gRPC client uses a 5s ConnectParams MaxDelay; the default 120s cap
  was keeping the SubChannel in backoff for tens of seconds after the
  daemon came back, masking the recovery.
2026-05-18 12:33:46 +02:00
Eduard Gert
f84b1df857 remove unused import 2026-05-18 11:37:55 +02:00
Eduard Gert
c24349e4f1 add overlay when daemon not available 2026-05-18 11:37:42 +02:00
Eduard Gert
7f7bee630f update about settings dev version, keep profile switch in sync between ui and tray 2026-05-18 10:56:27 +02:00
Eduard Gert
4e0eb9f2d4 Merge remote-tracking branch 'origin/ui-refactor' into ui-refactor 2026-05-18 10:41:12 +02:00
Eduard Gert
38a367e0cd update markdown files 2026-05-18 10:39:39 +02:00
Eduard Gert
78fb15e327 update profile context 2026-05-18 10:39:32 +02:00
Eduard Gert
35e58a2796 update connection switch 2026-05-18 10:39:22 +02:00
Eduard Gert
a6278936af replace openRoute with Event.Emit for needsLogin 2026-05-18 10:39:03 +02:00
Eduard Gert
32f62f3ed8 add profile switched event 2026-05-18 10:38:13 +02:00
Zoltán Papp
7fae703a27 [client/ui] Port IPv6 toggle and paired default-route filter to Wails UI
Brings two main-side PRs' UI behavior across the Fyne→Wails rewrite:

- #5631 (IPv6 overlay support): add "Enable IPv6" row to the polished
  SettingsNetwork tab; the legacy screens/Settings.tsx already had it,
  but modules/settings/SettingsNetwork.tsx (the user-visible Settings
  window) was missing the toggle.
- #6150 (mirror v4 exit selection onto v6 pair): replace the literal
  "0.0.0.0/0" || "::/0" filter in screens/Networks.tsx with an
  isDefaultRoute() helper that handles the daemon's merged-range
  display string (e.g. "0.0.0.0/0, ::/0"), so paired v4/v6 exit
  nodes are classified correctly.
2026-05-18 10:25:18 +02:00
Zoltán Papp
f468f15a30 Merge branch 'main' into ui-refactor
# Conflicts:
#	client/ui/network.go
2026-05-18 10:24:31 +02:00
Eduard Gert
5bdccfe8f4 add i18n to frontend 2026-05-15 16:22:14 +02:00
Zoltan Papp
cccb0e9230 [ci] Generate Wails bindings in release, bump wails to alpha.91
The bindings under client/ui/frontend/bindings are gitignored (1ebb507),
so the release UI job has to regenerate them before pnpm build — the
@wailsio/runtime Vite plugin refuses to build without them. Add a
wails3 CLI install step (version derived from go.mod via go list -m,
so it stays in sync with the runtime the binary links against), plus a
goreleaser before-hook that runs `wails3 generate bindings -clean=true
-ts` ahead of the existing pnpm install + pnpm build pair.

Bump github.com/wailsapp/wails/v3 to v3.0.0-alpha.91 in the process.
The @wailsio/runtime npm package stays at "latest" since the registry
only goes up to alpha.79 — the binding generator and the runtime stay
compatible across that gap until the binding shape changes.
2026-05-15 13:46:05 +02:00
Zoltan Papp
9d8eb76746 [client/ui] Replace update event fan-out with typed UpdateState API
The auto-update feature was driven by two narrow Wails events
(netbird:update:available and :progress) plus a SystemEvent-metadata
iteration on the React side. Both surfaces had to know the daemon
metadata schema (new_version_available, enforced, progress_window),
and the frontend had no pull endpoint to seed its state on mount.

Extract the state machine into a new client/ui/updater package, mirroring
how i18n and preferences are split between domain logic and a thin
services facade. The package owns the State type, the metadata-key
parsing, the mutex-guarded Holder, and the single netbird:update:state
event. services.Update keeps the daemon RPCs (Trigger, GetInstallerResult,
Quit) and gains GetState as a Wails pull endpoint.

Tray-side update behaviour moves out of tray.go into a dedicated
trayUpdater (tray_update.go): owns its menu item, OS notification,
click handler, and the /update window opener triggered by the
daemon's progress_window:show. tray.go drops three callbacks and four
fields, and reads hasUpdate through the updater.

Frontend ClientVersionContext now seeds from Update.GetState() and
subscribes to netbird:update:state; the status.events iteration and
metadata-key string literals are gone. UpdateAvailableBanner renders
only for the enforced && !installing branch and labels its action
"Install now"; UpdateVersionCard splits the install vs. download
branches by Enforced so the disabled flow routes to GitHub.
2026-05-15 13:31:17 +02:00
Eduard Gert
1ebb507cbb remove bindings from git 2026-05-15 13:01:19 +02:00
Eduard Gert
5411fa4350 remove old code, add german lang 2026-05-15 12:56:09 +02:00
Zoltan Papp
17cae1a75c [client/ui] Introduce localisation (i18n + preferences) feature packages
Adds a tray + React translation pipeline driven by a single JSON locale
tree (frontend/src/i18n/locales) embedded into the Go binary. The tray
re-renders on language switch via a Localizer that subscribes to the
preferences store.

Layout:
- client/ui/i18n: Bundle, LanguageCode, Language, errors, embedded-FS
  loader. Pure domain, no Wails/daemon deps.
- client/ui/preferences: Store + UIPreferences for user-scope UI state,
  persisted under os.UserConfigDir()/netbird/ui-preferences.json with
  atomic writes and a subscribe/broadcast channel.
- client/ui/services: thin Wails-binding facades (services.I18n,
  services.Preferences) so React sees ctx-first signatures.
- client/ui/localizer.go: tray bridge that owns the active language,
  exposes T()/StatusLabel() and re-paints the menu on prefs change.
- tray.go: every user-facing const replaced by translation keys via
  t.loc.T(...); menu rebuild + state replay on language switch.
- main.go: //go:embed all:frontend/src/i18n/locales, wires Bundle ->
  Store -> Localizer -> Wails facades in order.

Frontend API exposed via Wails bindings: I18n.Languages, I18n.Bundle,
Preferences.Get, Preferences.SetLanguage, plus the
netbird:preferences:changed event.

Includes regenerated Wails TS bindings (peers/profileswitcher/etc.
re-emitted as part of the build) and en/hu seed bundles.
2026-05-15 11:19:00 +02:00
Eduard Gert
c0b0eeb6ab update claude.md and rename windowmanager 2026-05-15 10:49:44 +02:00
Eduard Gert
d32721d7fc merge ui stuff 2026-05-15 10:20:51 +02:00
Eduard Gert
288f8dec08 Merge branch 'ui-refactor' into ui-refactor-ui 2026-05-15 10:16:30 +02:00
Eduard Gert
db8c9a0e30 add window manager 2026-05-15 10:14:01 +02:00
Zoltan Papp
505fcc7f7a [client/ui] Move profile-switch suppression from tray to Peers service
The optimistic Connecting paint and the Idle/stale-Connected
suppression lived in the tray's applyStatus, so only the tray got the
smoothed-out transition during a profile switch — the React Status
page (useStatus hook in frontend) subscribes to the same
netbird:status event and was seeing the raw daemon stream, complete
with the Disconnected blink.

Move the policy one layer up into the Peers service, between
SubscribeStatus and the Wails event bus, so every consumer downstream
sees the same filtered stream:

  * Peers gains BeginProfileSwitch / CancelProfileSwitch / shouldSuppress.
    BeginProfileSwitch sets the in-progress flag and emits a synthetic
    Connecting status so both the tray and React paint Connecting
    immediately. shouldSuppress swallows the daemon's stale Connected
    (peer-count teardown) and transient Idle (Down between flows)
    until Connecting / NeedsLogin / LoginFailed / SessionExpired /
    DaemonUnavailable indicates the new profile's flow has started,
    or a 30s safety timeout fires.

  * ProfileSwitcher.SwitchActive calls peers.BeginProfileSwitch when
    wasActive (prevStatus was Connected or Connecting) — the only
    cases where the daemon emits the blink-inducing sequence. Other
    prevStatuses already terminate cleanly on Idle.

  * Tray loses its switchInProgress fields, applyOptimisticConnecting
    helper, applyStatus suppression switch, and switchProfile's
    optimistic-paint call. handleDisconnect now calls
    Peers.CancelProfileSwitch alongside cancelling switchCancel, so
    the abort path bypasses the suppression filter and the daemon's
    Idle paints through immediately.

The full prevStatus -> action / optimistic label / suppressed events
matrix now lives in the ProfileSwitcher struct godoc, with the
suppression-rule-per-incoming-status table on the Peers struct
godoc — together they describe the click-time policy and the
stream-filter behaviour without duplication.

Wails bindings need regenerating to pick up Peers.BeginProfileSwitch
and Peers.CancelProfileSwitch.
2026-05-15 10:01:26 +02:00
Zoltan Papp
0fe8764707 [client/ui] Optimistic Connecting on profile switch, status row disabled
Three UX fixes for the tray's profile-switch flow:

* Optimistic Connecting paint when switching from Connected/Connecting.
  Previously the daemon's Down step emitted Idle before the new
  profile's Up emitted Connecting, leaving the tray flashing
  "Disconnected" for the duration of the Down. switchProfile now sets a
  flag and synthesizes a Connecting paint at click time; applyStatus
  suppresses the transient Idle and the stale Connected updates that
  arrive during the old profile's teardown, clearing the flag only when
  the new profile's flow surfaces (Connecting, NeedsLogin, LoginFailed,
  SessionExpired, DaemonUnavailable, or a 30s safety timeout).

* Disconnect during an in-flight switch now actually disconnects. The
  switchCancel context cancels the ProfileSwitcher.SwitchActive
  goroutine so its queued Up RPC never fires, and the
  switchInProgress flag is cleared so the daemon's Idle push paints
  through immediately. Without this, the user's Disconnect click was
  followed seconds later by the switcher's Up bringing the new
  profile back online.

* The first menu row is informational only. SetEnabled(false) is
  re-applied to t.statusItem (initial build, applyStatus, and the
  optimistic paint) and the openRoute("/login") OnClick handler is
  dropped — every actionable transition flows through the
  Connect/Disconnect entries below.

The switchProfile and applyStatus godocs carry the full
prevStatus → suppressed-events / final-state transition tables so
future readers don't have to rebuild the policy from the code.
2026-05-14 15:44:30 +02:00
Zoltan Papp
c0e7c61c4b [client] Close giveUpChan in connectWithRetryRuns defer
The trailing close(giveUpChan) at the bottom of the function only ran on
the backoff.Retry path. The DisableAutoConnect path returned early via
the if-block, skipping the close entirely. That branch is hit whenever
the active profile has auto-connect disabled — so every Down for those
profiles waited the full 5s timeout in the Down RPC select (and twice
when two Downs queued up, since both snapped the same never-closing
chan).

Move close(giveUpChan) into the existing defer so it fires on every
exit path: DisableAutoConnect return, backoff.Retry return, or panic.
The close happens after clientRunning=false is committed under the
mutex, so a Down/Up that wakes on the chan-close doesn't observe a
half-state where the chan is closed but clientRunning is still true.

Updates the Down RPC comment to point at the deferred close as the
signal source, and reframes the 5s timeout warning as "the goroutine
is wedged in a slow teardown step" rather than the expected case.
2026-05-14 15:44:15 +02:00
Zoltan Papp
e4eedbe18f [client/ui] Mirror tray profile switch to user-side ProfileManager
The Fyne UI used to write the active profile to both fronts on every
switch (profile.go:264-273): the daemon SwitchProfile RPC for
/var/lib/netbird/active_profile.json, then profileManager.SwitchProfile
for the user-side ~/Library/Application Support/netbird/active_profile.
The Wails ProfileSwitcher only kept the first.

Without the user-side mirror, a UI tray switch updates the daemon's
state but the CLI ProfileManager.GetActiveProfile() still returns the
stale "default". The next "netbird up" then sends ProfileName="default"
in the Login/Up request, and the daemon silently switches back to
default, reverting whatever the user just picked in the tray.

Mirror the daemon switch with profilemanager.NewProfileManager().
SwitchProfile after the daemon RPC succeeds. The daemon stays the
authority — a user-side write failure is logged as a warning, not a
hard error.
2026-05-14 14:52:14 +02:00
Zoltan Papp
fc1db63fc3 [client/ui] Fix profile-submenu race, restore Connect re-auth flow
Three tray fixes after the SubscribeStatus stream refactor:

* loadProfiles now serializes via a dedicated profileLoadMu and runs
  AFTER the SetHidden/SetEnabled writes inside applyStatus's iconChanged
  block. Previously the status-driven refresh fired before the menu-item
  writes finished, so processMenu's clearMenu/re-add NSMenu rebuild
  raced against SetHidden on darwin — the Disconnect entry could end
  up visible-but-disabled even when applyStatus had requested it hidden.

* The status row is no longer a hidden "Login" entry. It now renders
  as a plain enabled label (so the text isn't greyed out) but has no
  OnClick handler — clicks are no-ops, matching the legacy Fyne UI.
  All actionable transitions flow through Connect/Disconnect.

* handleConnect routes NeedsLogin/SessionExpired/LoginFailed to the
  frontend's /login route (which already runs Login → WaitSSOLogin →
  Up) instead of calling Up directly. The plain Up RPC errors with
  "up already in progress: current status NeedsLogin" in those
  states; the legacy Fyne UI drove the SSO dance from the Connect
  button as well.
2026-05-14 14:52:03 +02:00
Zoltan Papp
d841a6aa07 [client] Push status snapshot on every state.Set and classify SSO errors
Two related daemon-side status-stream fixes that together keep the UI's
status in sync with the daemon's contextState:

* state.Set previously only mutated the in-memory enum — transitions
  that weren't accompanied by a Mark{Management,Signal,...} call (e.g.
  StatusNeedsLogin after a PermissionDenied login, StatusLoginFailed
  after OAuth init failure, StatusIdle in the Login defer) left the
  UI stuck on the previous snapshot until an unrelated peer event
  happened to fire notifyStateChange. Add a callback on contextState
  fired from Set (outside the mutex, to avoid lock-order issues with
  the recorder's stateChangeMux), and wire it in Server.Start to the
  recorder's new public NotifyStateChange. Every state.Set callsite
  now pushes automatically; new ones don't need to opt in.

* WaitSSOLogin's WaitToken error branch lumped every failure into
  StatusLoginFailed, including context.Canceled aborts from a parallel
  profile switch (actCancel/waitCancel). That spurious LoginFailed
  then wedged the new profile's Up RPC with "up already in progress:
  current status LoginFailed". Split the branch by error type:
  context.Canceled lets the top-level defer pick StatusIdle,
  context.DeadlineExceeded sets StatusNeedsLogin (retryable; OAuth
  device-code window just expired), other errors keep LoginFailed
  (real auth/IO failures). Document the full state-transition table
  in the function godoc.
2026-05-14 14:51:51 +02:00
Eduard Gert
258e7ec038 Merge branch 'refs/heads/ui-refactor' into ui-refactor-ui
# Conflicts:
#	client/ui/frontend/src/screens/Profiles.tsx
#	client/ui/main.go
2026-05-13 16:51:57 +02:00
Eduard Gert
1932b76f5b update stuff 2026-05-13 16:28:51 +02:00
Zoltan Papp
d33b841a33 [client/ui] Use type conversion for ProfileRef to UpParams (staticcheck) 2026-05-13 16:07:21 +02:00
Zoltan Papp
df1935da6d [client/ui] Regenerate Wails bindings after UpParams and ProfileSwitcher changes 2026-05-13 16:05:46 +02:00
Zoltan Papp
eb6be5a2f3 [client/ui] Always use async Up in the UI service layer
The UI never needs to block on Up — status updates flow via the
SubscribeStatus stream. Hardcode Async:true in Connection.Up and remove
the Async field from UpParams so frontend callers are unaffected.
2026-05-13 16:02:24 +02:00
Zoltan Papp
209f14fc2f [client/ui] Cancel in-flight profile switch on rapid profile changes
Store a switchCancel in Tray. Each switchProfile call cancels the
previous in-flight goroutine before starting a new one. Because gRPC
respects context cancellation, the previous Down/Up RPCs are aborted
and rapid clicks always converge to the last selected profile.
2026-05-13 16:00:31 +02:00
Zoltan Papp
2bd56ecf67 [client/ui] Remove goroutine from ProfileSwitcher.SwitchActive
Down and Up(async=true) are both fast RPCs; no background goroutine
is needed. SwitchActive is now fully synchronous — the tray wraps the
call in its own goroutine, and Wails handles React calls similarly.
2026-05-13 15:55:59 +02:00
Zoltan Papp
67988c2407 [client/ui] Make profile Switch sync, Down+Up async in ProfileSwitcher
Switch RPC errors are now returned synchronously to the caller so the
tray can show a toast immediately on invalid-profile or other early
failures. Down and Up run in a background goroutine so the caller
returns fast; Up still uses async=true so the goroutine is short-lived.
2026-05-13 15:54:33 +02:00
Zoltan Papp
53b2fb8dc1 [client/ui] Add async Up mode to avoid blocking profile switches
The daemon's Up RPC previously always blocked in waitForUp (up to 50s)
until the engine connected. The UI does not need this — status updates
already flow through the SubscribeStatus stream.

Add bool async = 4 to UpRequest. When true the daemon starts
connectWithRetryRuns and returns immediately; the CLI path (async=false,
the default) is unchanged.

ProfileSwitcher.SwitchActive now sets Async:true so all three RPCs
(Status, Switch, Down, Up) return quickly. The background goroutine and
its associated race condition are removed entirely.
2026-05-13 15:51:36 +02:00
Zoltan Papp
803144e569 [client/ui] Unify profile-switching logic in ProfileSwitcher service
Both the tray and the React Profiles page previously had separate
switching logic: the tray applied a status-aware reconnect policy
(Down for error states, Up only when previously Connected/Connecting),
while the React page always called Switch + Up unconditionally with no
Down for LoginFailed/NeedsLogin/SessionExpired.

Introduce a single ProfileSwitcher service that encapsulates the full
reconnect policy. SwitchActive queries the current daemon status, calls
Switch, and launches Down/Up in a background goroutine so the caller
returns immediately after the Switch RPC completes. Both the tray and
the React Profiles page now delegate to this service.

Export the daemon status string constants (StatusConnected, etc.) from
the services package so tray.go no longer duplicates them as private
constants.
2026-05-13 15:46:00 +02:00
Zoltan Papp
c0cd88a3d0 [client/ui] Fix stale LoginFailed/NeedsLogin state after profile switch
When the active profile was in LoginFailed, NeedsLogin, or SessionExpired,
switching to another profile left the daemon holding stale management/signal
errors. The new profile inherited the error state from the previous one.

Two fixes:
1. server.go Down(): reset statusRecorder management/signal errors so the
   next Up() starts with a clean status snapshot instead of the previous
   profile's error state.
2. tray.go switchProfile(): add NeedsLogin/LoginFailed/SessionExpired to
   the needsDown set. Down() is called to flush stale daemon state, but
   Up() is not — the user initiates login on the new profile manually.
2026-05-13 15:13:20 +02:00
Zoltan Papp
6c9b821bf0 [client/ui] Show active profile name and account email in tray menu
The Profiles submenu label now reflects the active profile name instead
of the static "Profiles" text. A disabled email item appears directly
below it in the main menu, matching the legacy Fyne/systray behaviour.

Email is read from the per-profile state file via profilemanager in the
UI process — not through the daemon RPC — because the daemon runs as
root and its getConfigDir() resolves to the root home directory, making
the user-owned state file inaccessible from the daemon side.
2026-05-13 14:13:50 +02:00
Eduard Gert
83030dbbd6 Merge branch 'ui-refactor' into ui-refactor-ui 2026-05-13 10:12:26 +02:00
Eduard Gert
1c8a6e3798 wip 2026-05-13 10:11:38 +02:00
Zoltan Papp
74ea03da9b [ci] Fix Windows installer icon/banner paths missed in ui-wails rename
The ui-wails -> ui rename deleted the fyne installer assets but left the
NSIS and WiX scripts pointing at client/ui/assets/netbird.ico, which broke
the Windows Installer CI job. Point both scripts at the Wails-side icon
(client/ui/build/windows/icon.ico) and restore banner.bmp into the new
build directory so the NSIS welcome/finish sidebar keeps rendering.
2026-05-13 02:28:43 +02:00
Zoltan Papp
77fdf23a50 [ci] Drop Mesa3D opengl32.dll bundling from Windows installer
Wails3 renders via WebView2 on Windows, so the software-OpenGL
fallback needed by the previous Fyne UI is no longer required.
2026-05-13 01:40:16 +02:00
Zoltan Papp
1f4ed5c8ef [ci] Install Wails GTK deps on Linux lint/test runners
Add libwebkit2gtk-4.1-dev and libsoup-3.0-dev to apt installs so the
Wails v3 client/ui package compiles on Linux CI runners.
2026-05-13 01:39:12 +02:00
Zoltan Papp
e1bf362675 [client/ui] Refresh tray menu after status-indicator bitmap change
Wails v3 alpha's setMenuItemBitmap on darwin calls NSMenuItem.setImage
from whichever thread invokes SetBitmap — unlike the sibling setters
for label/disabled/hidden/checked, which dispatch_sync onto the main
queue. The off-thread AppKit call doesn't redraw, so the coloured
status dot stayed stale until the user closed and reopened the menu.

Force a tray.SetMenu rebuild after updating the bitmap; the rebuild
runs processMenu inside InvokeSync, which applies the bitmap to a fresh
NSMenuItem on the main thread and macOS picks it up immediately.
2026-05-12 21:46:05 +02:00
Zoltan Papp
af40ee52f8 [client/ui] Auto-reconnect tray profile switch when daemon was active
Picking a profile from the tray submenu only ran SwitchProfile on the
daemon, so the in-flight retry loop kept dialing the previous profile's
management server. The fix is to follow up Switch with Down+Up, but only
when the daemon was actively trying to be online — Connected or
Connecting. Idle / NeedsLogin / LoginFailed / SessionExpired stay as
deliberate waiting points so a profile pick doesn't surprise the user
with an SSO redirect or flip an intentionally offline daemon online.

The decision table lives in the switchProfile godoc.
2026-05-12 21:40:29 +02:00
Zoltan Papp
4988f2aa68 [client/ui] Refresh Profiles submenu by rebuilding the tray menu
Wails v3 alpha's submenu.Update() builds a fresh, detached NSMenu on
darwin instead of mutating the one attached to the parent menu item at
initial setup, so the visible Profiles entries stayed frozen on the
empty snapshot captured when the tray was registered: clicks reached
the new Go MenuItem objects (and the daemon SwitchProfile RPC ran), but
the checkmark never moved and reopening the menu still showed the old
selection.

Cache the top-level menu and call tray.SetMenu(t.menu) after each
loadProfiles refresh; macosSystemTray.setMenu clears and rebuilds the
entire NSMenu tree against the cached pointer, which propagates submenu
content changes to the visible menu.

Also adds INFO logs around profile click / SwitchProfile RPC / list
refresh so the active-profile flow is observable end-to-end.
2026-05-12 21:24:52 +02:00
Zoltan Papp
e3efaa5e59 [client] Fix tray flicker and stuck Connecting during management retry
The status snapshot tore down on every management retry because
state.Status() blanks the status when an error is wrapped, and the
SubscribeStatus stream propagated that as FailedPrecondition. The UI
treated any stream error as "daemon not running" and flickered the tray
to Not running between retries.

Disconnect was also unresponsive: Down set Idle before the retry
goroutine exited, which then overwrote it with Set(Connecting) on the
next attempt; the backoff sleep (up to 15s) wasn't context-aware, so the
goroutine kept running long after actCancel.

- buildStatusResponse falls back to the underlying status (via new
  state.CurrentStatus) instead of breaking the stream on wrapped errors.
- UI only flips to DaemonUnavailable on codes.Unavailable / non-status
  errors, so a live daemon returning FailedPrecondition is not reported
  as down.
- connect retry uses backoff.WithContext so actCancel interrupts the
  inter-attempt sleep, and skips Wrap(err) when the dial fails due to
  ctx cancellation.
- Down sets Idle after waiting for giveUpChan, so the retry goroutine
  can no longer race the disconnect.
- Tray hides Connect during Connecting and keeps Disconnect enabled so
  the user can abort an in-flight connection attempt.
2026-05-12 20:38:30 +02:00
Zoltan Papp
100d25a062 [client/ui] Add Profiles submenu to the tray
Mirror the main branch's profile list: a Profiles submenu populated
from the daemon's ListProfiles RPC, with the active profile shown as
a checked entry and a click on any other entry switching to it via
SwitchProfile.

The initial fill is deferred to the Common.ApplicationStarted hook
because Menu.Update() short-circuits while app.running is false and
the Wails3 macOS impl would nil-deref on early-startup InvokeSync.
2026-05-12 20:11:08 +02:00
Zoltan Papp
04b4330393 [client/ui] Add coloured status dot to tray menu
Show a small dot next to the first menu entry that reflects the
daemon state: green for Connected, yellow for Connecting, blue for
NeedsLogin/SessionExpired, red for LoginFailed/Error, grey for
Idle/Disconnected and dark grey for DaemonUnavailable. PNGs are 24x24
with a pHYs chunk declaring 144 DPI so NSImage renders them at 12 pt
while keeping retina-sharp pixel data; circles are supersampled 8x for
smooth edges.

Idle now surfaces as "Disconnected" in the menu label, daemon-status
literals moved to status* constants, and Exit Node / Resources are
gated on the Connected state instead of just daemon availability.
2026-05-12 20:05:50 +02:00
Eduard Gert
c8e18585c6 add update context 2026-05-11 17:21:38 +02:00
Eduard Gert
1931a2c8a8 add update available icon 2026-05-11 17:11:25 +02:00
Eduard Gert
108d43e702 add flags, update peers list 2026-05-11 16:17:54 +02:00
Eduard Gert
842ef0d657 update macos icon 2026-05-11 15:40:04 +02:00
Eduard Gert
439f44c6b4 merge 2026-05-11 15:16:41 +02:00
Eduard Gert
b5a970155b Merge branch 'ui-refactor' into ui-refactor-ui 2026-05-11 15:15:11 +02:00
Eduard Gert
686e0d97f2 update Assets.car 2026-05-11 14:51:05 +02:00
Eduard Gert
0c287b6f4d fix vite dev server 2026-05-11 14:48:37 +02:00
Eduard Gert
f7f5946910 update components 2026-05-11 14:26:10 +02:00
Zoltan Papp
7a9f5a734f Merge branch 'main' into ui-refactor
Port IPv6 overlay support (#5631) into the Wails UI:
- Add DisableIPv6 config toggle to Settings (NetworkTab + services)
- Filter ::/0 alongside 0.0.0.0/0 as an exit-node route
- Suppress duplicate v6 default-route notifications in tray
2026-05-11 14:10:12 +02:00
Eduard Gert
1aae067aaa add settings skeleton 2026-05-11 13:58:41 +02:00
Zoltan Papp
28a7eba756 [client/ui] Remove unused xembed_host_other.go stub 2026-05-11 13:54:17 +02:00
Zoltan Papp
8841b950a2 [client/server] Stop retry loop after PermissionDenied login
Without marking the error as backoff.Permanent the outer retry re-enters
connect(), which resets the daemon state from NeedsLogin to Connecting
and makes the tray flicker between the two until the user logs in.
2026-05-11 13:43:53 +02:00
Eduard Gert
0c2702c0d7 update height and wording 2026-05-11 13:30:05 +02:00
Zoltan Papp
b43a09a1c7 [client/ui] Add tray icon for needs-login/login-failed states
The tray now switches to a dedicated lock icon when the daemon reports
NeedsLogin, SessionExpired or LoginFailed — the latter mirrors the CLI,
which groups these three statuses together as "needs authentication"
and prints the same "Run netbird up" prompt. The macOS template variant
reuses the existing error-macos PNG because the project's macOS tray
PNGs use a 2-color (black + transparent) convention that rsvg-convert
of the badge-style SVG sources can't reproduce. The earlier badge-style
SVG sketches in assets/svg/ are removed (they were marked as reference
only and never matched the shipping PNG design).
2026-05-11 13:22:30 +02:00
Zoltan Papp
595dfbb6f1 [client/ui] Distinguish "daemon not running" tray state
The status stream emits a synthetic StatusDaemonUnavailable when the
gRPC client or stream cannot be established, fired once per outage and
cleared on the next real snapshot. The tray maps it to a "Not running"
status label, switches the icon to the error variant, hides
Connect/Disconnect (neither would work without the daemon), and
disables Settings, Networks and Create Debug Bundle so the user is not
routed to pages that would just fail to load.
2026-05-11 12:22:47 +02:00
Zoltan Papp
7f560df9be [client/ui] Tray menu opens on click; hide window at startup
Left-click on the tray icon now opens the menu on every platform — the
window is reached through a new "Open NetBird" entry. Only the action
that matches the current daemon state is shown: Connect when
disconnected, Disconnect when connected. The main window starts hidden
and is only surfaced via the tray, single-instance launch, or daemon
events.
2026-05-11 12:01:46 +02:00
Zoltán Papp
09052949a2 [client/ui] Finish ui-wails rename (import paths, fyne deps)
Follow-up to the rename commit: the previous commit moved the files but
the post-mv string substitutions (Go imports, frontend bindings, CI
config paths) were not re-staged so they slipped through. This commit
applies those edits and removes the fyne dependencies from go.mod/go.sum
now that the legacy fyne UI is gone.
2026-05-11 11:33:35 +02:00
Zoltán Papp
9aef31ff53 [client/ui] Replace fyne UI with Wails (rename ui-wails to ui)
Removes the legacy fyne-based client/ui implementation and renames the
Wails replacement (client/ui-wails) to take its place at client/ui. Go
imports, frontend bindings, CI workflows, goreleaser configs and the
windows .syso icon path are updated to follow the rename.
2026-05-11 11:20:22 +02:00
Zoltán Papp
08f52f4517 [client/server] Allow clearing pre-shared key via SetConfig
The daemon ignored an empty OptionalPreSharedKey, so a UI/CLI request to
clear the pre-shared key was silently dropped. Pass the pointer through
unconditionally — profilemanager already handles the empty-string case.
2026-05-11 11:02:39 +02:00
Eduard Gert
18e3b5dd32 fix about 2026-05-11 09:37:14 +02:00
Eduard Gert
f3f9704c6f update about 2026-05-08 17:55:41 +02:00
Eduard Gert
4c3d4effbd update troubleshooting 2026-05-08 17:18:25 +02:00
Eduard Gert
3953fee5a4 update ssh and advanced settings tabs 2026-05-08 10:57:31 +02:00
Eduard Gert
adeaa49cda update switch 2026-05-07 17:27:56 +02:00
Eduard Gert
2c5d52a1bf update wording 2026-05-07 17:19:56 +02:00
Eduard Gert
70a755fbae add general settings 2026-05-07 16:47:52 +02:00
Eduard Gert
559da5d5b9 refactor 2026-05-07 15:00:36 +02:00
Eduard Gert
614ee11ac7 update CFBundleDisplayName 2026-05-07 14:19:34 +02:00
Eduard Gert
85080afa59 use new mac style icons 2026-05-07 14:14:26 +02:00
Zoltán Papp
a5cc8da054 [client] Pre-seed CustomActivator CLSID under HKCU AppUserModelId\NetBird
The Wails notifications service reads HKCU\Software\Classes\AppUserModelId\
<AppName>\CustomActivator on first startup; if present it uses that GUID
as the toast activator CLSID, otherwise it generates a fresh UUID and
writes it back. Without an installer-supplied value the per-machine GUID
diverges from the ToastActivatorCLSID baked into the Start Menu and
Desktop shortcuts, and the COM activator never fires when a toast is
clicked. Seed the same CLSID the shortcuts use so the two sides match.
2026-05-07 13:00:51 +02:00
Zoltán Papp
a4fd5a78b4 [client/ui-wails] Set application Name to "NetBird" for Windows toasts
Windows uses application.Options.Name as the toast AppUserModelID and as
the registry path the Wails notifier reads/writes its CustomActivator
under (HKCU\Software\Classes\AppUserModelId\<Name>). The MSI installer
seeds those under "NetBird"; with the previous "netbird-ui" Name the app
would have written under a different identity and the toast activator
CLSID the installer pre-registers would have been orphaned.
2026-05-07 12:59:01 +02:00
Eduard Gert
062a183e4e update settings nav 2026-05-07 12:40:04 +02:00
Eduard Gert
a2be41caf8 add about setting 2026-05-07 11:24:11 +02:00
Zoltán Papp
5b70989e3e [client/ui-wails] Make /update page faithful to the legacy auto-update dialog
Adds the missing info line ("Your client version is older than the
auto-update version set in Management. Updating client to: <version>.")
and replaces the spinner with the legacy 1-second dot animation
(Updating./.../...). Terminal-state wording now matches the Fyne UI
exactly: 15 min timeout, canceled, and "Update failed: <err>".

Ports mapInstallError from client/ui/update.go so daemon errors that
embed "deadline exceeded" / "canceled" hit the right branch instead of
falling through as a generic failure.

Detects the daemon dropping mid-upgrade (the legacy success signal):
if GetInstallerResult fails for 5s straight, call the new Update.Quit
service method to exit, mirroring app.Quit() in showInstallerResult.
2026-05-07 10:35:18 +02:00
Zoltán Papp
d324a5ff48 [ci] Stub frontend/dist before lint so the Wails embed pattern matches
client/ui-wails/main.go embeds all:frontend/dist, which is produced by
the frontend build and gitignored. Lint runs don't build the frontend,
so the directory is missing in CI and golangci-lint fails the typecheck.
Create a placeholder file before linting so the embed has something to
match.
2026-05-07 10:23:02 +02:00
Eduard Gert
debb558aa3 wip 2026-05-07 09:57:14 +02:00
Zoltán Papp
cce80f8276 [client/ui-wails] Drop dead freebsd branches in services/connection.go
The file's build constraint excludes freebsd, so the freebsd cases in
IsUnixDesktopClient and OpenURL were unreachable — staticcheck (SA4032)
fails the pre-push lint. Linux is the only Unix-desktop GOOS this
package compiles for, so collapse both checks accordingly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 18:00:51 +02:00
Zoltán Papp
05ee4e52b8 [client/ui-wails] Make the SSO login flow recoverable from a stuck state
A pending WaitSSOLogin parks the daemon on an OAuth UserCode forever
once the user closes the browser without completing the flow. The
frontend can't unblock that on its own — it needs the daemon to fire
its own actCancel(). Three fixes work together:

- Login() now issues a Down() before kicking off the new flow so a
  previously-stuck WaitSSOLogin is unwedged before we ask the daemon
  for fresh OAuth info.
- The Login page's Cancel button calls Down() before navigating away,
  so abandoning the flow mid-browser actually settles the daemon's
  in-flight WaitSSOLogin instead of leaving it pinned.
- Status keeps the Login button visible whenever we aren't Connected
  (including Connecting), so a UI restart that finds the daemon stuck
  in Connecting still has a one-click recovery path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 17:59:50 +02:00
Zoltán Papp
bb2bf673a0 [client/ui-wails] Wire up the SSO login flow end-to-end
Mirror the Fyne client's login path: the daemon Login RPC now defaults
ProfileName/Username from GetActiveProfile + the OS user and sets
IsUnixDesktopClient on Linux/FreeBSD so the daemon picks the SSO
browser flow. A new OpenURL service launches the user's default
browser via xdg-open / open / rundll32 (Fyne's openURL helper) — the
embedded WebKit's window.open silently fails for external URLs.

The frontend gains a Login page that drives the full Login →
window.open via OpenURL → WaitSSOLogin → Up sequence with progress
states. Status surfaces a Login button while the daemon reports
NeedsLogin/SessionExpired, and the tray's status row stops being a
purely-decorative label: it becomes a clickable Login entry whenever
re-authentication is required.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 17:48:47 +02:00
Zoltán Papp
91c745e5e8 [client/ui-wails] Tear down the whole tray popup tree on focus loss
Replace the per-submenu focus-out handler with a shared idle-deferred
recheck: when any popup loses focus, ask after the next event-loop
turn whether *any* of our popups still owns toplevel focus. If none
does, the user clicked outside the menu tree, so close every popup at
once instead of leaking the parent.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 17:19:55 +02:00
Zoltán Papp
68c38247f1 [client/ui-wails] Add submenu support to the XEmbed tray popup
Recursively walk dbusmenu children-display="submenu" entries when
flattening the SNI menu so the GTK popup can render nested items.
The C side renders submenu folders as labeled buttons that open a
child popup window aligned to the anchor row, kept on-screen with
horizontal flipping; the top-level popup no longer self-destructs
when focus transfers to one of its own submenus.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 17:17:54 +02:00
Zoltán Papp
8b8f38de1b [client/ui-wails] Show GUI and daemon versions in the About submenu
Restore the legacy Fyne UI's two disabled "GUI: x.y.z" / "Daemon: a.b.c"
entries under About so users (and support) can read the running
versions from the tray. The GUI line is baked in at build time via
version.NetbirdVersion() — the same -ldflags chain the rest of the
repo uses. The daemon line starts as "—" and is rewritten in
applyStatus on every Status snapshot whose DaemonVersion differs from
the last one we recorded, so a daemon restart with a new build
(e.g. after an enforced update) updates the menu automatically.

Drive-by: rename the local variable that shadowed the version package
in handleUpdate so the import resolves cleanly.
2026-05-06 16:55:52 +02:00
Zoltán Papp
2b272e74c8 [client/ui-wails] In-process StatusNotifierWatcher + XEmbed tray bridge
Wails3's Linux systray hands the icon off to whatever process owns
org.kde.StatusNotifierWatcher on the session bus. Bare WMs (Fluxbox,
OpenBox, i3, dwm, sway, vanilla GNOME without the AppIndicator
extension) ship no watcher, so the icon registration silently fails
and the tray never appears — leaving a tray-only app like NetBird
unreachable.

Add a Linux-only watcher fallback that claims the watcher name when
nobody else does, plus an XEmbed bridge so legacy X11 system trays
(_NET_SYSTEM_TRAY_S0) can still render the icon. Both no-op on other
platforms via build tags.

Pieces:
- tray_watcher_linux.go: claims org.kde.StatusNotifierWatcher on a
  private session bus, exports the bare RegisterStatusNotifierItem /
  RegisterStatusNotifierHost surface, and spins up an XEmbed host per
  registered SNI item.
- xembed_host_linux.go: per-item event loop. Polls X11 events with a
  50ms ticker, listens for the SNI NewIcon signal, dispatches Activate
  / context menu through dbusmenu (com.canonical.dbusmenu).
- xembed_tray_linux.{c,h}: the X11/cairo native bits. Window is created
  with CopyFromParent visual + ParentRelative background so transparent
  pixels show the toolbar beneath instead of solid black on 24-bit
  trays. cairo paints the IconPixmap with OVER blending so per-pixel
  alpha is honoured against the parent-relative base. GTK3 owns the
  context-menu popup; menu items round-trip through dbusmenu Event.
- tray_linux.go: forces WEBKIT_DISABLE_DMABUF_RENDERER=1 in init() so
  developers running `task dev` / launching the binary directly get the
  same software rendering path the .desktop launcher already enables;
  the deb/rpm Exec wrapper covers installed users.
- tray_watcher_other.go and xembed_host_other.go: build-tag stubs so
  main.go's startStatusNotifierWatcher() compiles on every platform.
- main.go: calls startStatusNotifierWatcher() before NewTray so the
  Wails systray's RegisterStatusNotifierItem call hits a watcher we
  control on bare WMs.
- build/linux/netbird-ui.desktop: regenerated by `task build` to wrap
  the dev launcher's Exec line with the WEBKIT_DISABLE_DMABUF_RENDERER
  env, matching what the tray_linux.go init does at runtime.

Adapted from work originally prototyped on the prototype/ui-wails branch.

Tested on Fluxbox (Debian 13): the icon appears in the slit/toolbar with
the toolbar's background showing through transparent pixels, left-click
opens the window, right-click brings up the GTK popup of the dbusmenu
items.
2026-05-06 16:47:35 +02:00
Zoltán Papp
e6cbf30415 [client/ui-wails] Surface daemon SessionExpired in the tray
Port the Fyne UI's onSessionExpire 1:1 to the Wails tray so an SSO token
expiry no longer leaves the user staring at a stale peer list. When
applyStatus sees the transition into the daemon's StatusSessionExpired,
fire a single OS notification (the lastStatus guard rate-limits it to
the transition itself, mirroring the Fyne sendNotification flag) and
bring the main window forward on the /login route so the frontend can
drive the renewed SSO flow. The Fyne client achieved the same end with
a runSelfCommand "login-url" helper; here the window is already
in-process so we route to it directly.
2026-05-06 15:57:34 +02:00
Zoltán Papp
490b60ad0e [ci] Suppress typecheck on the ui-wails embed instead of skipping main.go
The previous attempt added client/ui-wails/main.go to the file path
exclude list, but golangci-lint v2's path filter only suppresses
issues from rule-based linters; the typecheck pre-pass that compiles
the package still runs and fails with "pattern all:frontend/dist: no
matching files found" before any rule fires.

Replace the path-level skip with a targeted exclusions.rules entry
that matches just that diagnostic on just that file. The rest of
client/ui-wails (services/, tray.go, grpc.go, ...) keeps being linted
normally.

Validated locally by deleting frontend/dist and running
`golangci-lint run client/ui-wails/...` — 0 issues with this config.
2026-05-06 15:50:14 +02:00
Eduard Gert
553be144b4 add setting 2026-05-06 14:21:01 +02:00
Eduard Gert
c3f9514182 wip 2026-05-06 10:47:40 +02:00
Zoltán Papp
a8812d5fb1 Merge remote-tracking branch 'origin/main' into ui-refactor
# Conflicts:
#	go.mod
#	go.sum
2026-05-05 15:41:59 +02:00
Zoltán Papp
6f93cf6ac3 [client/ui-wails] Group Tray's services into a TrayServices struct
NewTray's eight-parameter signature crossed Sonar's seven-parameter
threshold once Update joined the dependency list. Bundle the six service
pointers (Connection, Settings, Profiles, Peers, Notifier, Update) into
a TrayServices struct, leaving NewTray with three arguments — the two
Wails platform handles plus the service bag. Tray.svc replaces the
individual fields; call sites use t.svc.Connection etc.

Adding another service later is now a one-line struct change instead
of a NewTray signature break.
2026-05-05 15:37:25 +02:00
Zoltán Papp
18909390c2 [ci] Use go list -e so the ui-wails embed doesn't blank the test list
The previous fix added /client/ui-wails to the grep -v / Where-Object
filter, but go list aborts at the first broken package and emits an
empty stdout when client/ui-wails/main.go's //go:embed all:frontend/dist
fails to resolve. The command substitution then expands to nothing, and
`go test` falls back to the repo root — which has no Go files and fails
the job.

`go list -e` keeps listing remaining packages after a parse error, so
the existing path-based filter now actually does its job.

Touches all three test workflows (Linux native + docker, Darwin, Windows).
2026-05-05 15:30:40 +02:00
Zoltán Papp
b3eb5f2453 [ci] Skip lockfiles in codespell
pnpm-lock.yaml and package-lock.json embed package hashes that look
like English words to codespell (e.g. "nD" -> "and"), causing false
positives that can't be fixed because the lockfile is auto-generated.
Add the standard lockfile patterns to the skip list alongside the
existing go.mod/go.sum/proxy-web entries.
2026-05-05 15:15:15 +02:00
Zoltán Papp
dc02542a9e [ci] Skip client/ui-wails/main.go in golangci-lint
main.go uses //go:embed all:frontend/dist, which fails the typecheck
phase when frontend/dist is empty (the release pipeline populates it
via `pnpm build`; the lint workflow does not). Excluding just main.go
keeps the rest of the package — services/, tray.go, grpc.go, the
signal handlers — in scope.
2026-05-05 15:12:49 +02:00
Zoltán Papp
0c136fffb9 [ci] Add sonar-project.properties to exclude the Wails React frontend
Sonar's default scanner picks up TypeScript / JSX from the frontend
tree but applies rules that don't fit a UI codebase reviewed visually
(component dead-code detection, hook-shape conventions, ...). Skip
client/ui-wails/frontend from both analysis and coverage so neither
the rules engine nor the coverage gate trips on UI changes.

The Go side of the Wails UI (client/ui-wails/*.go, services/) is left
in scope on purpose — same Go standards as the rest of the repo.
2026-05-05 15:10:23 +02:00
Zoltán Papp
fffb9dd219 [client/ui-wails] Add Forwarding service for the exposed-services list
Surfaces the daemon's existing ForwardingRules RPC as a Wails service so
the React frontend can render the reverse-proxy / exposed-services list
in the planned dashboard.

Forwarding.List() returns one ForwardingRule per active rule with
protocol, destination port (single or range), translated address /
hostname, and translated port. The PortInfo oneof from the proto is
flattened to a `{port?: number, range?: {start, end}}` shape so TS
consumers don't have to peek at proto-internal type discriminators.

Regenerate frontend/bindings (forwarding.ts, models.ts, index.ts) so
the React side picks up the new service. peers.ts churn is a doc
comment refresh only — no API change.
2026-05-05 13:53:40 +02:00
Zoltán Papp
93275f9052 Bump github.com/wailsapp/wails/v3 to v3.0.0-alpha.84
Picks up the alpha.84 patch series. The only API change relative to
alpha.78 is a new macOS Liquid Glass effect option (NSGlassEffectView)
that NetBird does not use, so this is a drop-in dependency bump.

netbird-ui builds cleanly, go vet has no new findings, and the existing
Linux tray workaround (skip AttachWindow + OnClick on Linux) is still
required — Wails3 systemtray_linux.go's openMenu remains a "not
implemented on Linux" stub and SystemTray.applySmartDefaults still
auto-installs ToggleWindow as the click handler when a window is
attached.

The alpha CLI's transitive github.com/goreleaser/nfpm/v2 v2.44.1 is not
imported by any NetBird production binary (verified with `go list -deps`
on netbird-ui and the daemon entry points); it only ships inside the
wails3 developer CLI used for local packaging. The Snyk advisory for
nfpm therefore does not affect netbird-ui or the daemon.
2026-05-05 13:09:37 +02:00
Zoltán Papp
dd9c15072f [ci] Skip client/ui-wails in go test runs
main.go embeds frontend/dist with //go:embed, so any go-list-based test
sweep that touches the package fails at compile time before pnpm build
has populated the directory. The release pipeline runs the frontend
build via the goreleaser before-hook; the test workflows do not, and
should not, ship a Node toolchain just to compile a UI binary that has
no Go-side unit tests anyway.

Add a /client/ui-wails exclude to the test go-list filter on Linux,
Darwin and Windows.
2026-05-05 12:56:59 +02:00
Zoltán Papp
4c743bc03d Merge remote-tracking branch 'origin/main' into ui-refactor
# Conflicts:
#	client/internal/peer/status.go
#	client/proto/daemon.pb.go
#	client/proto/daemon_grpc.pb.go
#	go.mod
2026-05-05 12:49:09 +02:00
Zoltán Papp
2e61b42e92 [client/ui-wails] Slim the tray menu, move toggles to Settings page
The Fyne 1:1 tray pulled the entire daemon-config knobset (Allow SSH,
Connect on Startup, Quantum-Resistance, Lazy Connections, Block Inbound,
Notifications) into a Settings submenu — useful in a tray-only UI but
redundant now that the Wails app has a real Settings page. Drop the
submenu and route a single top-level "Settings" entry to /settings;
"Create Debug Bundle" stays at the top level for support workflows.

Side effects:
  - flipFlag and ptrBool go away with the checkbox callbacks.
  - loadConfig keeps seeding notificationsEnabled (the tray still gates
    OS toasts in onSystemEvent on it) but no longer mirrors any other
    config field.
  - Unused menu/notify constants (Allow SSH, Connect on Startup, ...,
    notifyErrorSettingsFmt) are removed from the central const block.
2026-05-05 12:19:41 +02:00
Zoltán Papp
3f8de2a149 [client/ui-wails] Hide Dock entry on macOS via LSUIElement
The legacy Fyne client and the sign-pipelines-built .pkg both run NetBird
in macOS Accessory mode (LSUIElement=1) — tray-only, no Dock entry, no
Cmd-Tab presence. The Wails build's bundled Info.plist (used by `task
darwin:package` for local development) didn't carry the flag, so the
.app bundle a developer builds locally diverged from the signed release.

Add LSUIElement to both Info.plist and Info.dev.plist so the local dev
flow matches what users see.
2026-05-05 12:03:09 +02:00
Zoltán Papp
bc609c3ae7 [client/ui-wails] Wire up enforced-update tray menu item
Surface the Fyne UI's "Download latest version" / "Install version X.Y.Z"
About-submenu entry in the Wails tray. The item starts hidden and is
revealed by onUpdateAvailable when the daemon emits EventUpdateAvailable;
opt-in updates open github.com/netbirdio/netbird/releases/latest in the
browser, enforced updates surface the in-window /update progress page
and call TriggerUpdate on the daemon.

Also lift every user-facing string and external URL in tray.go into
named const declarations at the top of the file, so future copy edits
and (eventual) localisation have a single source of truth.

The /update React route is the frontend counterpart and is owned by the
React side of the refactor.
2026-05-05 11:56:57 +02:00
Zoltán Papp
e3994d0c99 [client] Drop Mesa3D opengl32.dll, bootstrap WebView2 in Windows installers
Wails3 uses the WebKit-style WebView2 runtime instead of Fyne's OpenGL
backend, so the Mesa3D opengl32.dll payload that the Fyne build needed
for RDP/VM rendering can leave the .exe and .msi installers. Add a
WebView2 bootstrap step that probes the EdgeUpdate registry markers
(both HKLM\WOW6432Node and HKCU) and silently runs
MicrosoftEdgeWebview2Setup.exe only if the runtime is missing.

NSIS uses an inline macro adapted from Wails3's wails_tools.nsh; WiX
uses a deferred CustomAction gated on RegistrySearch properties. Both
expect the bootstrapper payload at client/MicrosoftEdgeWebview2Setup.exe,
which the sign-pipelines build step generates with `wails3 generate
webview2bootstrapper`. The matching sign-pipelines change lives in
that repo's PR.

The uninstall section keeps an unconditional `Delete opengl32.dll` so
upgrades from older Fyne builds clean up the leftover file.
2026-05-04 17:36:30 +02:00
Zoltán Papp
ba6e10cef3 [client/ui-wails] Pad macOS tray PNGs for proper menubar sizing
Wails3's macOS systray sets the NSImage size to the status bar thickness
(~22pt) on a square frame. The legacy Fyne PNGs had almost no horizontal
margin (the logo filled all 256x256), so under that explicit resize the
glyph stretched to the full menubar height — noticeably larger than
neighbouring SF Symbols-style indicators.

Pad each *-macos.png from 256x256 to 366x366 with transparent gravity:center
extent, leaving the glyph at ~70% of the rendered size. Same source PNGs,
no resampling, just more breathing room around the alpha-only template.
2026-05-04 17:12:12 +02:00
Zoltán Papp
ce53981b55 [client/ui-wails] Fix Windows manifest version format
Win32 assembly manifests require a four-part version (MAJOR.MINOR.BUILD.REVISION
per the Microsoft schema). The Wails template shipped a three-part "0.0.1",
which Windows rejects with "Activation context generation failed (...) The
value 0.0.1 of attribute version in element assemblyIdentity is invalid",
so the .exe never reaches main(). Pad to "0.0.1.0".
2026-05-04 16:20:15 +02:00
Zoltán Papp
a69037630b [client/ui-wails] Skip tray click-to-toggle on Linux
GNOME Shell + AppIndicator extension opens the attached menu on
left-click in addition to firing SNI Activate, so binding the window
toggle to the click handler made both the window and the menu pop on a
single click. The default Wails3 SystemTray.applySmartDefaults made it
worse: AttachWindow alone is enough to install ToggleWindow as the
implicit click handler, so dropping OnClick wasn't sufficient.

Mirror the legacy Fyne client: skip both AttachWindow and OnClick on
Linux and expose the main window through an explicit "Open NetBird"
menu item. Windows and macOS keep the click-to-toggle behaviour where
the OS cleanly separates left and right click.
2026-05-04 16:08:10 +02:00
Zoltán Papp
df58935cc0 [client/ui-wails] Set NetBird window and app icon on Linux
Wails3 falls back to its bundled bird logo when no Icon is supplied via
application.Options or LinuxWindow. Embed the 256x256 NetBird PNG and
wire it through both fields, plus set ProgramName=netbird so GTK's
g_set_prgname picks up the icon from the installed .desktop file. Tested
on Fedora 40 + KDE Plasma; the titlebar and taskbar now show the NetBird
logo.
2026-05-04 14:34:45 +02:00
Zoltán Papp
a1743dbf9b [client/ui-wails] Fix Fedora ayatana-appindicator package name
The RPM dependency name on Fedora is libayatana-appindicator-gtk3 (not
libayatana-appindicator3 — that's the Debian/Ubuntu spelling). Verified
with dnf install on Fedora 40.
2026-05-04 14:00:52 +02:00
Zoltán Papp
f9771de3f5 [client/ui-wails] Switch release pipelines from Fyne to Wails UI
Repoint goreleaser configs and the release workflow at client/ui-wails so
the published Linux deb/rpm, Windows binaries and macOS UI binaries are
built from the Wails source. Linux nfpm deps swap libappindicator/Fyne
GL stack for libgtk-3, libwebkit2gtk-4.1 and libayatana-appindicator3,
and the packaged .desktop file launches the binary with
WEBKIT_DISABLE_DMABUF_RENDERER=1 so RDP/VM sessions render correctly.
Frontend bindings are now committed; the release jobs add Node 20 and
pnpm 9 and run the frontend build via the goreleaser before-hook.
2026-05-04 13:00:13 +02:00
Eduard Gert
bfe19fa542 wip 2026-05-04 10:15:29 +02:00
Eduard Gert
d07f25fc49 wip 2026-05-04 10:14:41 +02:00
Eduard Gert
670b0f66ac Merge branch 'ui-refactor' into ui-refactor-ui 2026-04-30 14:57:32 +02:00
Eduard Gert
15d73a2edd Add connect toggle 2026-04-30 13:22:43 +02:00
Zoltán Papp
88a2bf582d [client] Push-based status stream for the Wails UI
Adds a SubscribeStatus gRPC RPC that pushes a fresh FullStatus snapshot
on every peer-recorder state change, replacing the Wails UI's 2-second
Status poll. The daemon's notifier already triggers on Connected /
Disconnected / Connecting / management or signal flip / address
change / peers-list change; we now coalesce those into ticks on a
buffered chan and stream the resulting snapshots over gRPC.

- Status recorder gains SubscribeToStateChanges /
  UnsubscribeFromStateChanges + a non-blocking notifyStateChange that
  drops ticks when a subscriber's 1-slot buffer is full (next snapshot
  the consumer pulls already reflects everything).
- Server.Status handler split: the snapshot composition is shared
  with the new SubscribeStatus stream handler so unary and stream
  paths return identical bytes.
- UI peers service: pollLoop replaced by statusStreamLoop. The local
  name of the existing SubscribeEvents loop is now toastStreamLoop so
  the two streams are easy to tell apart — the underlying RPC name is
  unchanged.
- Tray applyStatus skips the icon refresh when connected/lastStatus
  hasn't changed; rapid SubscribeStatus bursts during health probes
  no longer churn Shell_NotifyIcon or the log.
2026-04-30 11:45:43 +02:00
Zoltán Papp
0148d926d5 [client/ui-wails] Use original Fyne tray PNGs and drop the .ico split
The SVG-derived tray icons + multi-resolution .ico path looked correct on
disk but Wails3's Shell_NotifyIcon update never landed on the running
Windows tray — the icon stayed frozen on the .exe resource regardless of
how many times we called SetIcon. Single-PNG fed through the same path
updates correctly, so revert to the source-of-truth PNGs that ship with
the legacy Fyne UI and remove the icons_windows.go / tray_icon_*.go
split. The 6 colored tray PNGs and 6 macOS-template PNGs come from
client/ui/assets verbatim. Generation pipeline (assets/svg/) is gone.
2026-04-29 18:54:51 +02:00
Zoltán Papp
8f16a19b8f [client/ui-wails] Add windows:build:console task for log debugging
The default Windows build links the binary as a GUI subsystem app, so
stdout/stderr is detached from the launching terminal — invisible logrus
output makes tray and event-stream bugs hard to chase. Add a sibling task
that links as console subsystem and writes a separately-named binary so
the production output is preserved.

Usage:
  CGO_ENABLED=1 task windows:build:console
  bin\netbird-ui-console.exe   # logs print to the launching cmd/PowerShell
2026-04-29 16:21:45 +02:00
Zoltán Papp
504dceedf3 [client] Add Wails3 + React desktop UI scaffold
Stage 1 of the client/ui (Fyne) replacement. Adds a new client/ui-wails
module that runs on Linux/macOS/Windows from a single React + Vite +
Tailwind frontend driven by a thin gRPC services layer in Go.

- Single-module integration (no submodule): merge Wails3 into root go.mod
  with build tags !android !ios !freebsd !js so cross-compiles on those
  targets exclude the package automatically.
- Seven gRPC-bound services: Connection, Settings, Networks, Profiles,
  Debug, Update, Peers. Peers bridges Status polling and SubscribeEvents
  to the Wails event bus (netbird:status, netbird:event).
- Tray + window shell mirrors the Fyne menu 1:1 with hide-on-close,
  SIGUSR1 / Windows named-event for external "show window" triggers.
- React pages cover functional parity for Status, Settings (3 tabs),
  Networks (3 tabs), Profiles, Debug, Update, QuickActions, LoginUrl.
- SVG-sourced tray icons (12 source SVGs incl. macOS template variants)
  rasterized to PNG via task common:generate:tray:icons.
- Linux launcher sets WEBKIT_DISABLE_DMABUF_RENDERER=1 in the .desktop
  Exec= line and in task linux:run so the app renders correctly under
  RDP, VirtualBox, KVM, and bare WMs (Fluxbox/dwm) without DRM access.
2026-04-29 11:10:23 +02:00
449 changed files with 36370 additions and 15576 deletions

View File

@@ -45,11 +45,13 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
# which fails to compile until the frontend has been built. The Wails UI
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
# before goreleaser.
# `go list -e` lets the listing succeed even though the embed fails to
# resolve; the grep then drops the broken package by path. Without -e,
# go list aborts with empty stdout and `go test` falls back to the repo
# root, which has no Go files.
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list -e ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui)
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,client

View File

@@ -53,7 +53,7 @@ jobs:
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: steps.cache.outputs.cache-hit != 'true'
@@ -145,7 +145,7 @@ jobs:
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
@@ -158,15 +158,15 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags "devcert integration" -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,client
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
# which fails to compile until the frontend has been built. The Wails UI
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
# before goreleaser.
# `go list -e` lets the listing succeed even though the embed fails to
# resolve; the grep then drops the broken package by path. Without -e,
# go list aborts with empty stdout and `go test` falls back to the repo
# root, which has no Go files.
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list -e ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui)
test_client_on_docker:
name: "Client (Docker) / Unit"
@@ -228,7 +228,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags "devcert integration" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -e -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
'
test_relay:
@@ -284,17 +284,9 @@ jobs:
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test ${{ matrix.raceFlag }} \
-exec 'sudo' -coverprofile=coverage.txt \
-exec 'sudo' \
-timeout 10m -p 1 ./relay/... ./shared/relay/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,relay
test_proxy:
name: "Proxy / Unit"
needs: [build-cache]
@@ -342,15 +334,7 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test -timeout 10m -p 1 -coverprofile=coverage.txt ./proxy/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,proxy
go test -timeout 10m -p 1 ./proxy/...
test_signal:
name: "Signal / Unit"
@@ -401,17 +385,9 @@ jobs:
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' -coverprofile=coverage.txt \
-exec 'sudo' \
-timeout 10m ./signal/... ./shared/signal/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,signal
test_management:
name: "Management / Unit"
needs: [build-cache]
@@ -477,18 +453,10 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=devcert -coverprofile=coverage.txt \
go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/... ./shared/management/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,management
benchmark:
name: "Management / Benchmark"
needs: [build-cache]
@@ -727,14 +695,6 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=integration -coverprofile=coverage.txt \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/server/http/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: integration,management

View File

@@ -65,8 +65,15 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- name: Generate test script
# Exclude client/ui: its main.go uses //go:embed all:frontend/dist,
# which fails to compile until the frontend has been built. The Wails UI
# has no Go-side unit tests, and its release pipeline runs `pnpm build`
# before goreleaser.
# `go list -e` lets the listing succeed even though the embed fails to
# resolve; the Where-Object pipeline then drops the broken package by
# path. Without -e, go list aborts with empty stdout.
run: |
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
$packages = go list -e ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' } | Where-Object { $_ -notmatch '/client/ui' }
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd

View File

@@ -22,7 +22,11 @@ jobs:
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
skip: go.mod,go.sum,**/proxy/web/**
# Non-English UI translations trip codespell on real foreign words
# (de: "Sie", "oder", "ist"). Only en/common.json is the source of
# truth that should be spell-checked. Add each new locale dir here
# when a language is added under client/ui/i18n/locales/.
skip: go.mod,go.sum,**/proxy/web/**,**/pnpm-lock.yaml,**/package-lock.json,client/ui/i18n/locales/de/**,client/ui/i18n/locales/hu/**
golangci:
strategy:
fail-fast: false
@@ -54,7 +58,16 @@ jobs:
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Stub Wails frontend bundle
# client/ui/main.go has //go:embed all:frontend/dist. The
# directory is produced by `pnpm run build` and is gitignored, so
# lint-only runs (no frontend toolchain) need a placeholder file
# for the embed pattern to match.
shell: bash
run: |
mkdir -p client/ui/frontend/dist
touch client/ui/frontend/dist/.embed-placeholder
- name: golangci-lint
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
with:

View File

@@ -20,30 +20,15 @@ jobs:
per_page: 100,
});
// Cover renamed .pb.go files in addition to plain edits.
// Renamed entries land under the new path with previous_filename
// pointing at the base-side name, so we read the base content
// from the old path when present.
const changedPbFiles = files
.filter(f => (f.status === 'modified' || f.status === 'renamed')
&& f.filename.endsWith('.pb.go'))
.map(f => ({
headPath: f.filename,
basePath: f.previous_filename || f.filename,
}));
if (changedPbFiles.length === 0) {
console.log('No modified or renamed .pb.go files to check');
const modifiedPbFiles = files.filter(
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
);
if (modifiedPbFiles.length === 0) {
console.log('No modified .pb.go files to check');
return;
}
// Matches the generator version headers protoc writes at the top
// of generated files:
// // protoc v3.21.12
// // protoc-gen-go v1.26.0
// // - protoc-gen-go-grpc v1.6.1 (grpc files prefix with "- ")
// The optional "- " prefix and the optional -gen-go / -gen-go-grpc
// suffixes keep the *_grpc.pb.go headers in scope.
const versionPattern = /^\s*\/\/\s+(?:-\s+)?protoc(?:-gen-go(?:-grpc)?)?\s+v[\d.]+/;
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const baseSha = context.payload.pull_request.base.sha;
const headSha = context.payload.pull_request.head.sha;
@@ -70,22 +55,20 @@ jobs:
}
const violations = [];
for (const file of changedPbFiles) {
for (const file of modifiedPbFiles) {
const [base, head] = await Promise.all([
getVersionHeader(file.basePath, baseSha),
getVersionHeader(file.headPath, headSha),
getVersionHeader(file.filename, baseSha),
getVersionHeader(file.filename, headSha),
]);
if (!base.ok || !head.ok) {
core.warning(
`Skipping ${file.headPath}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
);
continue;
}
if (base.lines.join('\n') !== head.lines.join('\n')) {
violations.push({
file: file.basePath === file.headPath
? file.headPath
: `${file.basePath} → ${file.headPath}`,
file: file.filename,
base: base.lines,
head: head.lines,
});

View File

@@ -194,9 +194,9 @@ jobs:
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
run: goversioninfo -icon client/ui/build/windows/icon.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
run: goversioninfo -arm -64 -icon client/ui/build/windows/icon.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
@@ -356,8 +356,18 @@ jobs:
- name: check git status
run: git --no-pager diff --exit-code
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Set up pnpm
uses: pnpm/action-setup@v3
with:
version: 11
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev gcc-mingw-w64-x86-64
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
@@ -376,10 +386,16 @@ jobs:
echo "/tmp/llvm-mingw-20250709-ucrt-ubuntu-22.04-x86_64/bin" >> $GITHUB_PATH
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Install wails3 CLI
# Version derived from go.mod so the binding generator always matches
# the wails runtime the binary links against.
run: |
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
- name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
run: goversioninfo -64 -icon client/ui/build/windows/icon.ico -manifest client/ui/build/windows/wails.exe.manifest -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
run: goversioninfo -arm -64 -icon client/ui/build/windows/icon.ico -manifest client/ui/build/windows/wails.exe.manifest -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
@@ -447,6 +463,20 @@ jobs:
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '22'
- name: Set up pnpm
uses: pnpm/action-setup@v3
with:
version: 11
- name: Install wails3 CLI
# Version derived from go.mod so the binding generator always matches
# the wails runtime the binary links against.
run: |
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
@@ -534,23 +564,6 @@ jobs:
- name: Move wintun.dll into dist
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download Mesa3D (amd64 only)
id: download-mesa3d
if: matrix.arch == 'amd64'
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
url: https://pkgs.netbird.io/mesa3d/MesaForWindows-x64-20.1.8.7z
destination: ${{ env.downloadPath }}\mesa3d.7z
sha256: 71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9
- name: Extract Mesa3D driver (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z"
- name: Move opengl32.dll into dist (amd64 only)
if: matrix.arch == 'amd64'
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download EnVar plugin for NSIS
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
@@ -573,6 +586,28 @@ jobs:
if: matrix.arch == 'amd64'
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
- name: Set up Go for wails3 CLI
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Install wails3 CLI
# Version derived from go.mod so the bootstrapper payload always
# matches the wails runtime the binary links against.
shell: bash
run: |
WAILS_VERSION=$(go list -m -f '{{.Version}}' github.com/wailsapp/wails/v3)
go install github.com/wailsapp/wails/v3/cmd/wails3@$WAILS_VERSION
- name: Stage WebView2 bootstrapper for installers
# Both client/installer.nsis and client/netbird.wxs reference
# client/MicrosoftEdgeWebview2Setup.exe. wails3 writes it there.
# The signing pipeline (netbirdio/sign-pipelines) does the same
# step for release builds; this mirrors it for PR sanity testing.
shell: bash
run: wails3 generate webview2bootstrapper -dir client
- name: Build NSIS installer
shell: pwsh
env:

View File

@@ -27,7 +27,7 @@ jobs:
with:
go-version-file: "go.mod"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
run: sudo apt update && sudo apt install -y -q libgtk-4-dev libwebkitgtk-6.0-dev libsoup-3.0-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
with:

View File

@@ -114,6 +114,16 @@ linters:
- linters:
- staticcheck
text: "QF1012"
# client/ui/main.go uses //go:embed all:frontend/dist; the
# directory is populated by `pnpm build` in the release pipeline
# and missing at lint time, so the embed parses to "no matching
# files found" — surfaced by golangci-lint's typecheck pre-pass.
# Suppress just that one diagnostic; the rest of the package
# (services/, tray.go, grpc.go, ...) still gets linted normally.
- linters:
- typecheck
path: client/ui/main\.go
text: "pattern all:frontend/dist"
paths:
- third_party$
- builtin$

View File

@@ -1,6 +1,15 @@
version: 2
project_name: netbird-ui
before:
hooks:
# Bindings are gitignored; regenerate before the frontend build so
# the @wailsio/runtime Vite plugin can resolve them (vite refuses to
# build without them).
- sh -c 'cd client/ui && wails3 generate bindings -clean=true -ts'
- sh -c 'cd client/ui/frontend && pnpm install --frozen-lockfile && pnpm build'
builds:
- id: netbird-ui
dir: client/ui
@@ -70,12 +79,15 @@ nfpms:
scripts:
postinstall: "release_files/ui-post-install.sh"
contents:
- src: client/ui/build/netbird.desktop
- src: client/ui/build/linux/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/assets/netbird.png
- src: client/ui/build/appicon.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
- libgtk-3-0
- libwebkit2gtk-4.1-0
- libayatana-appindicator3-1
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
@@ -89,12 +101,15 @@ nfpms:
scripts:
postinstall: "release_files/ui-post-install.sh"
contents:
- src: client/ui/build/netbird.desktop
- src: client/ui/build/linux/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/assets/netbird.png
- src: client/ui/build/appicon.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
- gtk3
- webkit2gtk4.1
- libayatana-appindicator-gtk3
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'

View File

@@ -1,6 +1,15 @@
version: 2
project_name: netbird-ui
before:
hooks:
# Bindings are gitignored; regenerate before the frontend build so
# the @wailsio/runtime Vite plugin can resolve them (vite refuses to
# build without them).
- sh -c 'cd client/ui && wails3 generate bindings -clean=true -ts'
- sh -c 'cd client/ui/frontend && pnpm install --frozen-lockfile && pnpm build'
builds:
- id: netbird-ui-darwin
dir: client/ui
@@ -20,8 +29,6 @@ builds:
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
universal_binaries:
- id: netbird-ui-darwin

View File

@@ -22,11 +22,19 @@ import (
"github.com/netbirdio/netbird/util"
)
// extendSessionFlag drives the `netbird login --extend` flow: refresh the
// SSO session expiry on the management server without tearing down the
// tunnel. Mutually exclusive with setup-key login (a setup-key cannot
// refresh an SSO-tracked peer — see auth.errSetupKeyOnSSOExpiredPeer).
var extendSessionFlag bool
func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
loginCmd.PersistentFlags().BoolVar(&extendSessionFlag, "extend", false,
"refresh the SSO session expiry without tearing down the tunnel (requires an active connection)")
}
var loginCmd = &cobra.Command{
@@ -61,6 +69,16 @@ var loginCmd = &cobra.Command{
return err
}
if extendSessionFlag {
if providedSetupKey != "" {
return fmt.Errorf("--extend cannot be combined with a setup key; setup keys can only enrol new peers")
}
if err := doExtendSession(ctx, cmd); err != nil {
return fmt.Errorf("extend session failed: %v", err)
}
return nil
}
// workaround to run without service
if util.FindFirstLogPath(logFiles) == "" {
if err := doForegroundLogin(ctx, cmd, providedSetupKey, activeProf); err != nil {
@@ -150,6 +168,65 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
return nil
}
// doExtendSession drives the daemon's RequestExtendAuthSession /
// WaitExtendAuthSession pair. The user is sent through a regular SSO flow
// (browser + verification URL) and the resulting JWT is forwarded to the
// management server's ExtendAuthSession RPC. The tunnel stays up
// throughout — no Down/Up, no network-map resync.
func doExtendSession(ctx context.Context, cmd *cobra.Command) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.RequestExtendAuthSessionRequest{}
// Pre-fill the IdP login hint from the active profile so the user
// doesn't have to retype their email. Best-effort: we still proceed
// without a hint if the lookup fails.
pm := profilemanager.NewProfileManager()
if active, perr := pm.GetActiveProfile(); perr == nil {
if profState, sperr := pm.GetProfileState(active.Name); sperr == nil && profState.Email != "" {
req.Hint = &profState.Email
}
}
startResp, err := client.RequestExtendAuthSession(ctx, req)
if err != nil {
return fmt.Errorf("start extend session: %v", err)
}
uri := startResp.GetVerificationURIComplete()
if uri == "" {
uri = startResp.GetVerificationURI()
}
openURL(cmd, uri, startResp.GetUserCode(), noBrowser, showQR)
waitResp, err := client.WaitExtendAuthSession(ctx, &proto.WaitExtendAuthSessionRequest{
DeviceCode: startResp.GetDeviceCode(),
UserCode: startResp.GetUserCode(),
})
if err != nil {
return fmt.Errorf("wait for extend session: %v", err)
}
if ts := waitResp.GetSessionExpiresAt(); ts.IsValid() && !ts.AsTime().IsZero() {
deadline := ts.AsTime().Local()
cmd.Printf("Session extended. New expiry: %s\n", deadline.Format("2006-01-02 15:04:05 MST"))
} else {
// Management reported the peer is not eligible (e.g. login
// expiration disabled on the account). Surface that fact
// instead of pretending the call succeeded.
cmd.Println("Session extension call completed, but the management server did not return a new deadline (peer may not be SSO-tracked or login expiration is disabled).")
}
return nil
}
func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) (*profilemanager.Profile, error) {
// switch profile if provided

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"strings"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
@@ -117,6 +118,11 @@ func statusFunc(cmd *cobra.Command, args []string) error {
profName = activeProf.Name
}
var sessionExpiresAt time.Time
if ts := resp.GetSessionExpiresAt(); ts.IsValid() {
sessionExpiresAt = ts.AsTime().UTC()
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
Anonymize: anonymizeFlag,
DaemonVersion: resp.GetDaemonVersion(),
@@ -127,6 +133,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
IPsFilter: ipsFilterMap,
ConnectionTypeFilter: connectionTypeFilter,
ProfileName: profName,
SessionExpiresAt: sessionExpiresAt,
})
var statusOutputString string
switch {

View File

@@ -12,13 +12,7 @@ var (
Short: "Print the NetBird's client application version",
Run: func(cmd *cobra.Command, args []string) {
cmd.SetOut(cmd.OutOrStdout())
out := version.NetbirdVersion()
if version.IsDevelopmentVersion(out) {
if commit := version.NetbirdCommit(); commit != "" {
out += "-" + commit
}
}
cmd.Println(out)
cmd.Println(version.NetbirdVersion())
},
}
)

View File

@@ -464,7 +464,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
if connect != nil {
engine := connect.Engine()
if engine != nil {
_ = engine.RunHealthProbes(false)
_ = engine.RunHealthProbes(context.Background(), false)
}
}

View File

@@ -1,11 +0,0 @@
//go:build android || (!linux && !windows)
package firewall
import "github.com/netbirdio/netbird/client/firewall/uspfilter"
// interfaceAllower returns no allower: these platforms have no host firewall to
// open for the interface.
func interfaceAllower(IFaceMapper, uint16) uspfilter.InterfaceAllower {
return nil
}

View File

@@ -1,10 +0,0 @@
//go:build windows
package firewall
import "github.com/netbirdio/netbird/client/firewall/uspfilter"
// interfaceAllower returns the Windows netsh-based interface allower.
func interfaceAllower(iface IFaceMapper, _ uint16) uspfilter.InterfaceAllower {
return uspfilter.NewWindowsInterfaceAllower(iface)
}

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"runtime"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
@@ -19,11 +21,13 @@ func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.
}
// use userspace packet filtering firewall
return uspfilter.Create(uspfilter.Config{
IFace: iface,
DisableServerRoutes: disableServerRoutes,
FlowLogger: flowLogger,
MTU: mtu,
InterfaceAllower: interfaceAllower(iface, mtu),
})
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}
err = fm.AllowNetbird()
if err != nil {
log.Warnf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}

View File

@@ -16,7 +16,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -30,107 +29,47 @@ const (
NFTABLES
)
// SkipNftablesEnv is the environment variable to skip nftables check
const SkipNftablesEnv = "NB_SKIP_NFTABLES_CHECK"
// errNoFirewallManager indicates no kernel firewall backend is present,
// as opposed to a backend that exists but failed to create or initialize.
var errNoFirewallManager = errors.New("no firewall manager found")
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// Userspace firewall without a native counterpart: routing is handled
// entirely in userspace. The interface is opened in the kernel's foreign
// filter chains via a table-less allower, except in netstack mode where no
// kernel interface exists.
if netstack.IsEnabled() || (iface.IsUserspaceBind() && forceUserspaceFirewall()) {
if netstack.IsEnabled() {
log.Info("netstack mode, using userspace firewall")
} else {
log.Info("forcing userspace firewall")
}
cfg := uspfilter.Config{
IFace: iface,
DisableServerRoutes: disableServerRoutes,
FlowLogger: flowLogger,
MTU: mtu,
InterfaceAllower: interfaceAllower(iface, mtu),
}
return uspfilter.Create(cfg)
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
log.Info("forcing userspace firewall")
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
fm, err := createNativeFirewall(iface, stateManager, mtu)
switch {
case err == nil && !iface.IsUserspaceBind():
// Nothing to do, fall through
case err == nil && iface.IsUserspaceBind():
// Native firewall handles packet filtering, but the userspace WireGuard bind
// needs a device filter for DNS interception hooks. Install a minimal
// hooks-only filter that passes all traffic through to the kernel firewall.
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
}
case err != nil && !iface.IsUserspaceBind():
// Kernel cannot fall back to anything else, need to return error
return nil, err
case err != nil && iface.IsUserspaceBind():
// Fall back to the userspace packet filter if native is unavailable
logNativeFirewallUnavailable(err)
return uspfilter.Create(uspfilter.Config{
IFace: iface,
DisableServerRoutes: disableServerRoutes,
FlowLogger: flowLogger,
MTU: mtu,
InterfaceAllower: interfaceAllower(iface, mtu),
})
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
// Kernel cannot fall back to anything else, need to return error
if !iface.IsUserspaceBind() {
return fm, err
}
// Fall back to the userspace packet filter if native is unavailable
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Native firewall handles packet filtering, but the userspace WireGuard bind
// needs a device filter for DNS interception hooks. Install a minimal
// hooks-only filter that passes all traffic through to the kernel firewall.
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
}
return fm, nil
}
// interfaceAllower selects how the userspace firewall opens the interface in
// foreign kernel chains: nftables when available (which also opens foreign nft
// tables), else iptables (the legacy fallback, filter INPUT only), else nil.
// firewalld trust is applied separately by the manager. Netstack has no kernel
// interface to open.
func interfaceAllower(iface IFaceMapper, mtu uint16) uspfilter.InterfaceAllower {
if netstack.IsEnabled() {
return nil
}
nftAllower, err := nbnftables.NewInterfaceAllower(iface, mtu)
if err == nil {
return nftAllower
}
log.Infof("no nftables interface allower: %v", err)
iptAllower, err := nbiptables.NewInterfaceAllower(iface)
if err == nil {
return iptAllower
}
log.Infof("no iptables interface allower: %v", err)
return nil
}
// logNativeFirewallUnavailable logs the fallback to userspace at info level
// when no kernel firewall backend exists, and at warn level otherwise.
func logNativeFirewallUnavailable(err error) {
if errors.Is(err, errNoFirewallManager) {
log.Infof("no native firewall backend available: %v. Proceeding with userspace", err)
} else {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, mtu uint16) (firewall.Manager, error) {
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
fm, err := createFW(iface, mtu)
if err != nil {
return nil, fmt.Errorf("create firewall: %w", err)
return nil, fmt.Errorf("create firewall: %s", err)
}
if err = fm.Init(stateManager); err != nil {
@@ -149,10 +88,29 @@ func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface, mtu)
default:
return nil, errNoFirewallManager
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found")
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
} else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
}
if errUsp != nil {
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType {
useIPTABLES := false
@@ -174,38 +132,35 @@ func check() FWType {
}
}
// Honor the skip env before probing nftables at all.
if os.Getenv(SkipNftablesEnv) != "true" {
nf := nftables.Conn{}
if chains, err := nf.ListChains(); err == nil {
if !useIPTABLES {
nf := nftables.Conn{}
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
if !useIPTABLES {
return NFTABLES
}
// search for chains where table is filter
// if we find one, we assume that nftables manager can be used with iptables
for _, chain := range chains {
if chain.Table.Name == "filter" {
return NFTABLES
}
}
// search for chains where table is filter
// if we find one, we assume that nftables manager can be used with iptables
for _, chain := range chains {
if chain.Table.Name == "filter" {
return NFTABLES
}
}
// check tables for the following constraints:
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
// 2. there is no tables or more than one table, we assume that nftables manager can be used
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
// 4. if we find an error we log and continue with iptables check
nbTablesList, err := nf.ListTables()
switch {
case err == nil && len(iptablesChains) > 0:
return IPTABLES
case err == nil && len(nbTablesList) != 1:
return NFTABLES
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
return IPTABLES
case err != nil:
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
}
// check tables for the following constraints:
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
// 2. there is no tables or more than one table, we assume that nftables manager can be used
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
// 4. if we find an error we log and continue with iptables check
nbTablesList, err := nf.ListTables()
switch {
case err == nil && len(iptablesChains) > 0:
return IPTABLES
case err == nil && len(nbTablesList) != 1:
return NFTABLES
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
return IPTABLES
case err != nil:
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
}
}
@@ -221,21 +176,15 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
return err == nil
}
// forceUserspaceFirewall reports whether the userspace firewall is forced.
// NB_FORCE_USERSPACE_ROUTER is an alias: forcing userspace routing implies the
// userspace firewall, since the two are no longer separable.
func forceUserspaceFirewall() bool {
return envForceBool(EnvForceUserspaceFirewall) || envForceBool(uspfilter.EnvForceUserspaceRouter)
}
func envForceBool(name string) bool {
val := os.Getenv(name)
val := os.Getenv(EnvForceUserspaceFirewall)
if val == "" {
return false
}
force, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", name, err)
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
return false
}
return force

View File

@@ -0,0 +1,554 @@
package iptables
import (
"errors"
"fmt"
"net"
"slices"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
ipset "github.com/lrh3321/ipset-go"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
tableName = "filter"
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
// external DNAT from bypassing ACL rules.
mangleFwdKey = "MANGLE-FORWARD"
)
type aclEntries map[string][][]string
type entry struct {
spec []string
position int
}
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
v6 bool
stateManager *statemanager.Manager
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
return &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
}, nil
}
func (m *aclManager) init(stateManager *statemanager.Manager) error {
m.stateManager = stateManager
m.seedInitialEntries()
m.seedInitialOptionalEntries()
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
if err := m.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
m.updateState()
return nil
}
func (m *aclManager) AddPeerFiltering(
id []byte,
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
if m.v6 && ipsetName != "" {
ipsetName += "-v6"
}
proto := protoForFamily(protocol, m.v6)
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
"-i", m.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
}
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
ipList.addIP(ip.String())
return []firewall.Rule{&Rule{
ruleID: uuid.New().String(),
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
specs: specs,
v6: m.v6,
}}, nil
}
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
}
}
if err := m.createIPSet(ipsetName); err != nil {
return nil, fmt.Errorf("create ipset: %w", err)
}
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
}
ipList := newIpList(ip.String())
m.ipsetStore.addIpList(ipsetName, ipList)
}
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end
if action == firewall.ActionDrop {
// Insert at the beginning of the chain (position 1)
err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...)
} else {
err = m.iptablesClient.Append(tableFilter, chain, specs...)
}
if err != nil {
return nil, err
}
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
log.Errorf("failed to add mangle rule: %v", err)
mangleSpecs = nil
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
mangleSpecs: mangleSpecs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
v6: m.v6,
}
m.updateState()
return []firewall.Rule{rule}, nil
}
// DeletePeerRule from the firewall by rule definition
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
shouldDestroyIpset := false
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
ip := net.ParseIP(r.ip)
if ip == nil {
return fmt.Errorf("parse IP %s", r.ip)
}
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
return fmt.Errorf("delete ip from ipset: %w", err)
}
delete(ipsetList.ips, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(ipsetList.ips) != 0 {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.ipsetStore.deleteIpset(r.ipsetName)
shouldDestroyIpset = true
}
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
}
if r.mangleSpecs != nil {
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if shouldDestroyIpset {
if err := m.destroyIPSet(r.ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy empty ipset: %v", err)
} else {
log.Errorf("destroy empty ipset: %v", err)
}
}
}
m.updateState()
return nil
}
func (m *aclManager) Reset() error {
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
m.updateState()
return nil
}
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range m.entries["INPUT"] {
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
for _, rule := range m.entries["FORWARD"] {
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
}
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
}
if err := m.destroyIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
}
}
m.ipsetStore.deleteIpset(ipsetName)
}
return nil
}
func (m *aclManager) createDefaultChains() error {
// chain netbird-acl-input-rules
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
return err
}
for chainName, rules := range m.entries {
// mangle FORWARD guard rules are handled separately below
if chainName == mangleFwdKey {
continue
}
for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
}
}
for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil
}
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
// We want to make sure our traffic is not dropped by existing rules.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished()
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
// Inbound is handled by our ACLs, the rest is dropped.
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
// ACL mark check where it cannot be overridden.
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
}
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2,
},
}
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}
func (m *aclManager) updateState() {
if m.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := m.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
if m.v6 {
currentState.ACLEntries6 = m.entries
currentState.ACLIPsetStore6 = m.ipsetStore
} else {
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
}
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// filterRuleSpecs returns the specs of a filtering rule
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
// don't use IP matching if IP is 0.0.0.0
matchByIP := !ip.IsUnspecified()
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
specs = append(specs, applyPort("--sport", sPort)...)
specs = append(specs, applyPort("--dport", dPort)...)
return specs
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string {
if ipsetName == "" {
return ""
}
actionSuffix := ""
if action == firewall.ActionDrop {
actionSuffix = "-drop"
}
switch {
case sPort != nil && dPort != nil:
return ipsetName + "-sport-dport" + actionSuffix
case sPort != nil:
return ipsetName + "-sport" + actionSuffix
case dPort != nil:
return ipsetName + "-dport" + actionSuffix
default:
return ipsetName + actionSuffix
}
}
func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if m.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add IP to ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
}
if err := ipset.Del(name, entry); err != nil {
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) flushIPSet(name string) error {
return ipset.Flush(name)
}
func (m *aclManager) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -1,352 +0,0 @@
//go:build !android
package iptables
import (
"fmt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) createContainers() error {
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFwdIn, tableFilter},
{chainRTFwdOut, tableFilter},
{chainRTPre, tableMangle},
{chainRTNAT, tableNat},
{chainRTRdr, tableNat},
{chainRTMSSClamp, tableMangle},
} {
// Fallback: clear chains that survived an unclean shutdown.
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.insertEstablishedRule(chainRTFwdIn); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.insertEstablishedRule(chainRTFwdOut); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
return nil
}
func (r *family) addJumpRules() error {
// Jump to nat chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPostrouting, 1, natRule...); err != nil {
return fmt.Errorf("add nat postrouting jump rule: %w", err)
}
r.rules[jumpNATPost] = natRule
// Jump to mangle prerouting chain
preRule := []string{"-j", chainRTPre}
if err := r.iptablesClient.Insert(tableMangle, chainPrerouting, 1, preRule...); err != nil {
return fmt.Errorf("add mangle prerouting jump rule: %w", err)
}
r.rules[jumpManglePre] = preRule
// Jump to nat prerouting chain
rdrRule := []string{"-j", chainRTRdr}
if err := r.iptablesClient.Insert(tableNat, chainPrerouting, 1, rdrRule...); err != nil {
return fmt.Errorf("add nat prerouting jump rule: %w", err)
}
r.rules[jumpNATPre] = rdrRule
return nil
}
func (r *family) setupDataPlaneMark() error {
var merr *multierror.Error
preRule := []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPrerouting, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
} else {
r.rules[markManglePre] = preRule
}
postRule := []string{
"-o", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPostrouting, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
} else {
r.rules[markManglePost] = postRule
}
return nberrors.FormatErrorOrNil(merr)
}
// seedInitialEntries adds default rules to the entries map. Rules are
// inserted at position 1, so the order here is reversed.
//
// Existing FORWARD policy decides outbound traffic towards our
// interface. If FORWARD policy is "drop", we add an
// established/related rule to allow return traffic for inbound rules.
func (r *family) seedInitialEntries() {
established := getConntrackEstablished()
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", chainACLInput})
r.appendToEntries(chainInput, append([]string{"-i", r.wgIface.Name()}, established...))
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
r.appendToEntries(chainForward, []string{"-o", r.wgIface.Name(), "-j", chainRTFwdOut})
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", chainRTFwdIn})
// Mangle FORWARD guard: when external DNAT redirects traffic from
// the wg interface, it traverses FORWARD instead of INPUT,
// bypassing ACL rules. ACCEPT rules in filter FORWARD can be
// inserted above ours. Mangle runs before filter, so these guard
// rules enforce the ACL mark check where it cannot be overridden.
r.appendToEntries(mangleForwardKey, []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
r.appendToEntries(mangleForwardKey, []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
}
func (r *family) seedInitialOptionalEntries() {
r.optionalEntries[chainForward] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2,
},
}
}
func (r *family) appendToEntries(chain chainKey, spec ruleSpec) {
r.entries[chain] = append(r.entries[chain], spec)
}
func (r *family) createDefaultChains() error {
if err := r.iptablesClient.NewChain(tableName, chainACLInput); err != nil {
return fmt.Errorf("create %s chain: %w", chainACLInput, err)
}
for chain, rules := range r.entries {
// mangle FORWARD guard rules are handled separately below
if chain == mangleForwardKey {
continue
}
for _, rule := range rules {
if err := r.iptablesClient.InsertUnique(tableName, string(chain), 1, rule...); err != nil {
return fmt.Errorf("insert jump rule into %s: %w", chain, err)
}
}
}
for chain, entries := range r.optionalEntries {
for _, entry := range entries {
if err := r.iptablesClient.InsertUnique(tableName, string(chain), entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
r.entries[chain] = append(r.entries[chain], entry.spec)
}
}
clear(r.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range r.entries[mangleForwardKey] {
if err := r.iptablesClient.AppendUnique(tableMangle, chainForward, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil
}
func (r *family) cleanUpDefaultForwardRules() error {
var merr *multierror.Error
// cleanJumpRules removes the OUTPUT jump to NETBIRD-NAT-OUTPUT among
// the others, so the chain below deletes cleanly instead of failing
// with "device or resource busy".
if err := r.cleanJumpRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("clean jump rules: %w", err))
}
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFwdIn, tableFilter},
{chainRTFwdOut, tableFilter},
{chainRTPre, tableMangle},
{chainRTNAT, tableNat},
{chainRTRdr, tableNat},
{chainNATOutput, tableNat},
{chainRTMSSClamp, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
continue
}
if ok {
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
merr = multierror.Append(merr, fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
}
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanJumpRules() error {
// locations maps each tracked jump rule to the built-in table and
// chain it was inserted into.
locations := map[firewall.RuleID]struct{ table, chain string }{
jumpNATPost: {tableNat, chainPostrouting},
jumpManglePre: {tableMangle, chainPrerouting},
jumpNATPre: {tableNat, chainPrerouting},
jumpMSSClamp: {tableMangle, chainForward},
jumpNATOutput: {tableNat, chainOutput},
}
var merr *multierror.Error
for ruleID, loc := range locations {
rule, exists := r.rules[ruleID]
if !exists {
continue
}
if err := r.iptablesClient.DeleteIfExists(loc.table, loc.chain, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule from chain %s in table %s: %w", loc.chain, loc.table, err))
continue
}
delete(r.rules, ruleID)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanAclChains() error {
var merr *multierror.Error
if err := r.cleanInputAclChain(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanPreroutingEntries(); err != nil {
merr = multierror.Append(merr, err)
}
for _, rule := range r.entries[mangleForwardKey] {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainForward, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete mangle %s guard rule %v: %w", chainForward, rule, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanInputAclChain() error {
ok, err := r.iptablesClient.ChainExists(tableName, chainACLInput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainACLInput, err)
}
if !ok {
return nil
}
var merr *multierror.Error
for _, rule := range r.entries[chainInput] {
if err := r.iptablesClient.DeleteIfExists(tableName, chainInput, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainInput, rule, err))
}
}
for _, rule := range r.entries[chainForward] {
if err := r.iptablesClient.DeleteIfExists(tableName, chainForward, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainForward, rule, err))
}
}
if err := r.iptablesClient.ClearAndDeleteChain(tableName, chainACLInput); err != nil {
merr = multierror.Append(merr, fmt.Errorf("clear and delete %s chain: %w", chainACLInput, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanPreroutingEntries() error {
ok, err := r.iptablesClient.ChainExists(tableMangle, chainPrerouting)
if err != nil {
return fmt.Errorf("check chain %s in %s: %w", chainPrerouting, tableMangle, err)
}
if !ok {
return nil
}
var merr *multierror.Error
for _, rule := range r.entries[chainPrerouting] {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPrerouting, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainPrerouting, rule, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanupDataPlaneMark() error {
var merr *multierror.Error
if preRule, exists := r.rules[markManglePre]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPrerouting, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
} else {
delete(r.rules, markManglePre)
}
}
if postRule, exists := r.rules[markManglePost]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPostrouting, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
} else {
delete(r.rules, markManglePost)
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -1,285 +0,0 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
ruleID := rule.ID()
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
return rule, nil
}
toDestination := rule.TranslatedAddress.String()
switch {
case len(rule.TranslatedPort.Values) == 0:
// no translated port, use original port
case len(rule.TranslatedPort.Values) == 1:
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
// need the "/originalport" suffix to avoid dnat port randomization
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
default:
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
proto := strings.ToLower(string(rule.Protocol))
rules := make(map[firewall.RuleID]ruleInfo, 3)
// DNAT rule
dnatRule := []string{
"!", "-i", r.wgIface.Name(),
"-p", proto,
"-j", "DNAT",
"--to-destination", toDestination,
}
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
rules[ruleID+dnatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTRdr,
rule: dnatRule,
}
// SNAT rule
snatRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "MASQUERADE",
}
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleID+snatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTNAT,
rule: snatRule,
}
// Forward filtering rule, if fwd policy is DROP
forwardRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "ACCEPT",
}
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleID+fwdSuffix] = ruleInfo{
table: tableFilter,
chain: chainRTFwdOut,
rule: forwardRule,
}
// Request forwarding once the rule is about to be installed, releasing
// it if installation fails so the refcount tracks the real rules.
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
for key, ruleInfo := range rules {
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
log.Errorf("rollback failed: %v", rollbackErr)
}
r.releaseForwarding()
return nil, fmt.Errorf("add rule %s: %w", key, err)
}
r.rules[key] = ruleInfo.rule
}
r.updateState()
return rule, nil
}
func (r *family) rollbackRules(rules map[firewall.RuleID]ruleInfo) error {
var merr *multierror.Error
for key, ruleInfo := range rules {
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
// On rollback error, add to rules map for next cleanup
r.rules[key] = ruleInfo.rule
}
}
if merr != nil {
r.updateState()
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
ruleID := rule.ID()
var merr *multierror.Error
var found bool
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
found = true
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
}
delete(r.rules, ruleID+dnatSuffix)
}
if snatRule, exists := r.rules[ruleID+snatSuffix]; exists {
found = true
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
}
delete(r.rules, ruleID+snatSuffix)
}
if fwdRule, exists := r.rules[ruleID+fwdSuffix]; exists {
found = true
if err := r.iptablesClient.Delete(tableFilter, chainRTFwdOut, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
}
delete(r.rules, ruleID+fwdSuffix)
}
r.updateState()
// Release once, only if the rule was present and removed.
if merr == nil && found {
r.releaseForwarding()
}
return nberrors.FormatErrorOrNil(merr)
}
// releaseForwarding drops one IP forwarding reference, logging any error.
func (r *family) releaseForwarding() {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("release IP forwarding: %v", err)
}
}
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
"--dport", strconv.Itoa(int(originalPort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
}
info := ruleInfo{
table: tableNat,
chain: chainRTRdr,
rule: dnatRule,
}
if err := r.iptablesClient.Append(info.table, info.chain, info.rule...); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = info.rule
r.updateState()
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
return fmt.Errorf("delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
func (r *family) ensureNATOutputChain() error {
if _, exists := r.rules[jumpNATOutput]; exists {
return nil
}
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
}
if !chainExists {
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
}
}
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Insert(tableNat, chainOutput, 1, jumpRule...); err != nil {
if !chainExists {
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
}
}
return fmt.Errorf("add OUTPUT jump rule: %w", err)
}
r.rules[jumpNATOutput] = jumpRule
r.updateState()
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
dnatRule := []string{
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
"--dport", strconv.Itoa(int(originalPort)),
"-d", localAddr.String(),
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
}
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
r.updateState()
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}

View File

@@ -1,253 +0,0 @@
//go:build !android
package iptables
import (
"fmt"
"maps"
"net/netip"
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableName = tableFilter
tableNat = "nat"
tableMangle = "mangle"
// chainACLInput is the peer ACL chain that holds installed
// peer-filtering rules.
chainACLInput = "NETBIRD-ACL-INPUT"
// mangleForwardKey is the entries map key for mangle FORWARD guard
// rules that prevent external DNAT from bypassing ACL rules.
mangleForwardKey chainKey = "MANGLE-FORWARD"
chainInput = "INPUT"
chainPostrouting = "POSTROUTING"
chainPrerouting = "PREROUTING"
chainForward = "FORWARD"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFwdIn = "NETBIRD-RT-FWD-IN"
chainRTFwdOut = "NETBIRD-RT-FWD-OUT"
chainRTPre = "NETBIRD-RT-PRE"
chainRTRdr = "NETBIRD-RT-RDR"
chainNATOutput = "NETBIRD-NAT-OUTPUT"
chainRTMSSClamp = "NETBIRD-RT-MSSCLAMP"
jumpManglePre = "jump-mangle-pre"
jumpNATPre = "jump-nat-pre"
jumpNATPost = "jump-nat-post"
jumpNATOutput = "jump-nat-output"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
matchSet = "--match-set"
dnatSuffix firewall.RuleID = "_dnat"
snatSuffix firewall.RuleID = "_snat"
fwdSuffix firewall.RuleID = "_fwd"
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
)
type ruleInfo struct {
chain string
table string
rule []string
}
type routeRules map[firewall.RuleID][]string
// ruleSpec is a single iptables rule expressed as its argument list
// (e.g. {"-i", "wg0", "-j", "DROP"}).
type ruleSpec []string
// chainKey identifies the chain a seeded entry belongs to. It holds
// built-in chain names ("INPUT", "FORWARD", "PREROUTING") plus the
// synthetic mangleForwardKey bucket for the mangle FORWARD guard rules.
type chainKey string
// aclEntries maps a chain to the rules seeded into it to jump into or
// guard the netbird ACL chains.
type aclEntries map[chainKey][]ruleSpec
type entry struct {
spec ruleSpec
position int
}
// ipsetCounter is the shared hash:net refcounter used by peer and
// route ACLs alike. The ipset library does not support comments, so
// the key is just the set name (string).
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
// family holds the per-address-family iptables state. One instance
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
// single family; the top-level Manager owns one for v4 and another
// for v6.
type family struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
v6 bool
// Peer ACL chain bookkeeping.
entries aclEntries
optionalEntries map[chainKey][]entry
// filters holds peer + route filter rules keyed by content hash.
// AddFilterRule writes here; DeleteFilterRule looks up by id.
filters map[nbid.RuleID]*Rule
ipsetCounter *ipsetCounter
// rules holds NAT, jump, and MSS-clamping rules (auxiliary
// plumbing that isn't a filter rule).
rules routeRules
// Routing / NAT.
legacyManagement bool
mtu uint16
ipFwdState *ipfwdstate.IPForwardingState
stateManager *statemanager.Manager
}
func newFamily(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*family, error) {
r := &family{
iptablesClient: iptablesClient,
wgIface: wgIface,
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
entries: make(aclEntries),
optionalEntries: make(map[chainKey][]entry),
filters: make(map[nbid.RuleID]*Rule),
rules: make(routeRules),
mtu: mtu,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
r.ipsetCounter = refcounter.New(
func(name string, sources []netip.Prefix) (struct{}, error) {
return struct{}{}, r.createIpSet(name, sources)
},
func(name string, _ struct{}) error {
return r.deleteIpSet(name)
},
)
return r, nil
}
// init wires the family to the state manager and installs both the
// route ACL containers and the peer ACL chain skeleton.
func (r *family) init(stateManager *statemanager.Manager) error {
r.stateManager = stateManager
if err := r.cleanUpDefaultForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
r.seedInitialEntries()
r.seedInitialOptionalEntries()
if err := r.cleanAclChains(); err != nil {
return fmt.Errorf("clean acl chains: %w", err)
}
if err := r.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
r.updateState()
return nil
}
// Reset tears down all firewall state owned by this family. ACL
// chain cleanup runs before route-chain cleanup because the route
// chains are still referenced by FORWARD jumps installed during
// seedInitialEntries; deleting them first would trip EBUSY.
func (r *family) Reset() error {
var merr *multierror.Error
if err := r.cleanAclChains(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanupDataPlaneMark(); err != nil {
merr = multierror.Append(merr, err)
}
clear(r.rules)
clear(r.filters)
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) updateState() {
if r.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := r.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
// Clone the rule maps so the persisted state holds a private snapshot.
// The live maps keep being mutated by subsequent rule operations while
// the state manager marshals the state from its periodic-save goroutine.
// Sharing the maps by reference races the two and aborts the process with
// a concurrent map iteration and write. The ipset counter guards itself
// during marshaling, so it can be shared directly.
if r.v6 {
currentState.RouteRules6 = maps.Clone(r.rules)
currentState.RouteIPsetCounter6 = r.ipsetCounter
currentState.ACLEntries6 = maps.Clone(r.entries)
} else {
currentState.RouteRules = maps.Clone(r.rules)
currentState.RouteIPsetCounter = r.ipsetCounter
currentState.ACLEntries = maps.Clone(r.entries)
}
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}

View File

@@ -1,341 +0,0 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nbnet "github.com/netbirdio/netbird/client/net"
)
// AddFilterRule installs a packet-filtering rule. With destination
// empty, the rule goes to the peer ACL input chain plus a paired
// mangle PREROUTING rule for the redirect mark. With destination set
// (prefix or named set), it goes to the route ACL forward chain.
// Multi-source rules collapse to one iptables rule via the shared
// hash:net ipset.
func (r *family) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
if existing, ok := r.filters[ruleID]; ok {
return existing, nil
}
srcMatch, err := r.applySourceMatch(sourceNetwork(sources), sources)
if err != nil {
return nil, fmt.Errorf("apply source match: %w", err)
}
rule, err := r.installFilterRule(ruleID, srcMatch, destination, proto, sPort, dPort, action)
if err != nil {
r.dropSourceMatch(srcMatch)
return nil, err
}
r.filters[ruleID] = rule
r.updateState()
return rule, nil
}
func (r *family) hasRule(id nbid.RuleID) bool {
_, ok := r.filters[id]
return ok
}
// hasDNATRule reports whether this family owns the DNAT rule set for
// the given user id. DNAT rules live in r.rules under the well-known
// "<id>_dnat" key; the lookup here is used by Manager.DeleteDNATRule
// to pick the right family.
func (r *family) hasDNATRule(id firewall.RuleID) bool {
_, ok := r.rules[id+dnatSuffix]
return ok
}
// DeleteFilterRule removes a previously installed filter rule. The
// rule's stored chain/table identify where to delete from; source set
// references are recovered from the spec via findSets and dropped
// from the shared ipset counter.
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
ruleID := rule.ID()
pr, ok := r.filters[ruleID]
if !ok {
log.Debugf("filter rule %s not found", ruleID)
return nil
}
// DeleteIfExists keeps both deletes idempotent so a retry after a
// partial failure does not error on the half that was already removed.
var merr *multierror.Error
if err := r.iptablesClient.DeleteIfExists(tableFilter, pr.chain, pr.specs...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule from %s: %w", pr.chain, err))
}
if pr.mangleSpecs != nil {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, pr.mangleSpecs...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete mangle rule: %w", err))
}
}
if merr != nil {
// Leave the rule tracked so the caller retries the remaining half.
return nberrors.FormatErrorOrNil(merr)
}
r.dropSourceMatch(pr.specs)
delete(r.filters, ruleID)
r.updateState()
return nil
}
// findSets scans an iptables rule spec for "-m set --match-set <name>
// <dir>" fragments and returns the named sets in occurrence order.
// Used at delete time to drop ipsetCounter references.
func findSets(rule []string) []string {
var sets []string
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
sets = append(sets, rule[i+3])
}
}
return sets
}
// sourceNetwork classifies a source-prefix list into the firewall.Network
// shape the rest of the spec-builder consumes: empty for match-any, a
// single prefix inline, or an ipset for multiple sources.
func sourceNetwork(sources []netip.Prefix) firewall.Network {
switch {
case len(sources) == 0:
return firewall.Network{}
case len(sources) == 1 && sources[0].Bits() == 0:
return firewall.Network{}
case len(sources) == 1:
return firewall.Network{Prefix: sources[0]}
default:
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
}
}
// applySourceMatch returns the iptables match fragment for the rule's
// source. For a Set it increments the shared ipset's refcount; for a
// Prefix it emits a direct -s match; for the wildcard it returns nil.
func (r *family) applySourceMatch(network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
switch {
case network.IsSet():
if r.ipsetCounter == nil {
return nil, fmt.Errorf("multi-source peer rule requires shared ipset counter")
}
name := r.ipsetName(network.Set.HashedName())
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
return nil, fmt.Errorf("ipset increment %s: %w", name, err)
}
return []string{"-m", "set", matchSet, name, "src"}, nil
case network.IsPrefix():
return []string{"-s", network.Prefix.String()}, nil
default:
return nil, nil
}
}
// dropSourceMatch undoes whatever applySourceMatch reserved. Safe to
// call when the spec is empty or holds only inline matchers. Decrement
// errors are logged but not returned: the filter rule has already been
// deleted at that point and we don't want to leak the deletion.
func (r *family) dropSourceMatch(srcMatch []string) {
if r.ipsetCounter == nil {
return
}
for _, name := range findSets(srcMatch) {
if _, err := r.ipsetCounter.Decrement(name); err != nil {
log.Errorf("rollback ipset decrement %s: %v", name, err)
}
}
}
// decrementSetCounter drops ipset references owned by a raw rule spec
// stored in r.rules (NAT / legacy route entries). It returns an error
// aggregate so the caller surfaces decrement failures.
func (r *family) decrementSetCounter(rule []string) error {
if r.ipsetCounter == nil {
return nil
}
var merr *multierror.Error
for _, name := range findSets(rule) {
if _, err := r.ipsetCounter.Decrement(name); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// installFilterRule assembles and writes one iptables filter-chain
// rule. With destination empty the rule lands in the peer ACL input
// chain and a paired mangle PREROUTING rule is added for the redirect
// mark. With destination set the rule lands in the route ACL forward
// chain and there is no mangle pairing.
func (r *family) installFilterRule(
ruleID nbid.RuleID,
srcMatch []string,
destination firewall.Network,
protocol firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) (*Rule, error) {
isRoute := !destination.IsZero()
proto := protoForFamily(protocol, r.v6)
specs := slices.Clone(srcMatch)
var destExp []string
if isRoute {
var err error
destExp, err = r.applyNetwork("-d", destination, nil)
if err != nil {
return nil, fmt.Errorf("apply network -d: %w", err)
}
specs = append(specs, destExp...)
}
specs = append(specs, filterMatchSpecs(proto, sPort, dPort)...)
var mangleSpecs []string
if !isRoute {
mangleSpecs = slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
"-i", r.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
}
specs = append(specs, "-j", actionToStr(action))
chain := chainACLInput
if isRoute {
chain = chainRTFwdIn
}
// Peer ACL drops are inserted at position 1 so they precede the
// chain's catch-all; route ACL drops are inserted at position 2
// to sit immediately after the established/related accept rule.
var err error
if action == firewall.ActionDrop {
pos := 1
if isRoute {
pos = 2
}
err = r.iptablesClient.Insert(tableFilter, chain, pos, specs...)
} else {
err = r.iptablesClient.Append(tableFilter, chain, specs...)
}
if err != nil {
r.dropSourceMatch(destExp)
return nil, fmt.Errorf("install filter rule on %s: %w", chain, err)
}
// The mangle redirect-mark rule is best effort: the filter rule itself
// is what enforces the ACL, so a mangle failure must not undo it. Drop
// the spec so teardown does not try to remove a rule that was not added.
if mangleSpecs != nil {
if err := r.iptablesClient.Append(tableMangle, chainRTPre, mangleSpecs...); err != nil {
log.Errorf("add mangle rule: %v", err)
mangleSpecs = nil
}
}
return &Rule{
id: ruleID,
specs: specs,
mangleSpecs: mangleSpecs,
chain: chain,
v6: r.v6,
}, nil
}
// applyNetwork resolves a firewall.Network into the iptables match
// fragment for the given direction flag (-s or -d). Set networks
// increment the shared ipset refcount; prefixes emit a direct match;
// an empty network returns no spec ("match any").
func (r *family) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
direction := "src"
if flag == "-d" {
direction = "dst"
}
if network.IsSet() {
name := r.ipsetName(network.Set.HashedName())
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return []string{"-m", "set", matchSet, name, direction}, nil
}
if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil
}
// nolint:nilnil
return nil, nil
}
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
// filterMatchSpecs returns the proto/port match fragment for a
// filtering rule. The source match (-s or -m set) is built by the
// caller and prepended.
func filterMatchSpecs(protocol string, sPort, dPort *firewall.Port) (specs []string) {
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
specs = append(specs, applyPort("--sport", sPort)...)
specs = append(specs, applyPort("--dport", dPort)...)
return specs
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil
}
if port.IsRange && len(port.Values) == 2 {
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
}
if len(port.Values) > 1 {
portList := make([]string, len(port.Values))
for i, p := range port.Values {
portList[i] = strconv.Itoa(int(p))
}
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
}
return []string{flag, strconv.Itoa(int(port.Values[0]))}
}

View File

@@ -1,93 +0,0 @@
package iptables
import (
"fmt"
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// InterfaceAllower opens the NetBird interface on the iptables filter INPUT
// chain so the host firewall doesn't drop traffic the userspace firewall
// handles. It is the fallback used when nftables is unavailable (an
// iptables-legacy host).
//
// It opens INPUT only: the userspace router never forwards in the kernel.
// firewalld trust is handled by the uspfilter manager, not here.
type InterfaceAllower struct {
ifaceName string
ipt4 *iptables.IPTables
// ipt6 is nil when the interface has no IPv6 overlay address.
ipt6 *iptables.IPTables
}
// NewInterfaceAllower builds an iptables allower for the interface. It returns
// an error when iptables is unavailable, so the caller can fall back to
// firewalld trust.
func NewInterfaceAllower(wgIface iFaceMapper) (*InterfaceAllower, error) {
ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("iptables not available: %w", err)
}
if _, err := ipt4.ListChains(tableFilter); err != nil {
return nil, fmt.Errorf("iptables filter table not available: %w", err)
}
a := &InterfaceAllower{ifaceName: wgIface.Name(), ipt4: ipt4}
// Missing v6 must not break the v4 path: open v4 only and continue.
if wgIface.Address().HasIPv6() {
ipt6, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("ip6tables not available, opening interface on v4 only: %v", err)
} else if _, err := ipt6.ListChains(tableFilter); err != nil {
log.Warnf("ip6tables filter table not available, opening interface on v4 only: %v", err)
} else {
a.ipt6 = ipt6
}
}
return a, nil
}
// Apply inserts the interface accept rule on the filter INPUT chain. It removes
// any stale rule first so an unclean exit (e.g. SIGKILL, where Close never ran)
// is recovered deterministically rather than accumulating duplicates.
func (a *InterfaceAllower) Apply() error {
var merr *multierror.Error
for _, ipt := range a.clients() {
if err := ipt.DeleteIfExists(tableFilter, chainInput, a.inputRule()...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("clean stale interface accept rule: %w", err))
}
if err := ipt.Insert(tableFilter, chainInput, 1, a.inputRule()...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add interface accept rule: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// Close removes the interface accept rule.
func (a *InterfaceAllower) Close() error {
var merr *multierror.Error
for _, ipt := range a.clients() {
if err := ipt.DeleteIfExists(tableFilter, chainInput, a.inputRule()...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove interface accept rule: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (a *InterfaceAllower) inputRule() []string {
return []string{"-i", a.ifaceName, "-j", "ACCEPT"}
}
func (a *InterfaceAllower) clients() []*iptables.IPTables {
clients := []*iptables.IPTables{a.ipt4}
if a.ipt6 != nil {
clients = append(clients, a.ipt6)
}
return clients
}

View File

@@ -1,104 +0,0 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"github.com/hashicorp/go-multierror"
"github.com/lrh3321/ipset-go"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func (r *family) createIpSet(setName string, sources []netip.Prefix) error {
if err := r.createIPSet(setName); err != nil {
return fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
// The refcounter records nothing when this callback errors,
// so destroy the set or it leaks in the kernel. A partial
// source set would also fail-open for deny rules, so the
// rule must fail rather than install with a missing source.
if derr := r.destroyIPSet(setName); derr != nil {
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
}
return fmt.Errorf("add element to set %s: %w", setName, err)
}
}
return nil
}
func (r *family) deleteIpSet(setName string) error {
if err := r.destroyIPSet(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
log.Debugf("deleted unused ipset %s", setName)
return nil
}
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
name := r.ipsetName(set.HashedName())
var merr *multierror.Error
for _, prefix := range prefixes {
if err := r.addPrefixToIPSet(name, prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
}
}
if merr == nil {
log.Debugf("updated set %s with prefixes %v", name, prefixes)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) ipsetName(name string) string {
if r.v6 {
return name + "-v6"
}
return name
}
func (r *family) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if r.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (r *family) addPrefixToIPSet(name string, prefix netip.Prefix) error {
addr := prefix.Addr()
ip := addr.AsSlice()
entry := &ipset.Entry{
IP: ip,
CIDR: uint8(prefix.Bits()),
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
}
return nil
}
func (r *family) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -3,6 +3,7 @@ package iptables
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
@@ -17,21 +18,25 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Manager of iptables firewall. Per-family state (peer ACLs, route
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
// by family and provides the public firewall.Manager surface.
type resetter interface {
Reset() error
}
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
wgIface iFaceMapper
ipv4Client *iptables.IPTables
family4 *family
aclMgr *aclManager
router *router
rawSupported bool
// IPv6 counterparts, nil when no v6 overlay
ipv6Client *iptables.IPTables
family6 *family
aclMgr6 *aclManager
router6 *router
}
// iFaceMapper defines subset methods of interface required for manager
@@ -52,9 +57,14 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
ipv4Client: iptablesClient,
}
m.family4, err = newFamily(iptablesClient, wgIface, mtu)
m.router, err = newRouter(iptablesClient, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create family: %w", err)
return nil, fmt.Errorf("create router: %w", err)
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
}
if wgIface.Address().HasIPv6() {
@@ -71,18 +81,21 @@ func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
if err != nil {
return fmt.Errorf("init ip6tables: %w", err)
}
m.ipv6Client = ip6Client
family6, err := newFamily(ip6Client, wgIface, mtu)
m.router6, err = newRouter(ip6Client, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 family: %w", err)
return fmt.Errorf("create v6 router: %w", err)
}
// Share the same IP forwarding state with the v4 family, since
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
family6.ipFwdState = m.family4.ipFwdState
m.router6.ipFwdState = m.router.ipFwdState
m.ipv6Client = ip6Client
m.family6 = family6
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
@@ -96,7 +109,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.family4.mtu,
MTU: m.router.mtu,
},
}
stateManager.RegisterState(state)
@@ -128,24 +141,31 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return nil
}
// initChains initializes the per-family firewall state for both
// address families, rolling back on failure.
// initChains initializes router and ACL chains for both address families,
// rolling back on failure.
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
type initStep struct {
name string
r *family
init func(*statemanager.Manager) error
mgr resetter
}
steps := []initStep{{"v4", m.family4}}
steps := []initStep{
{"router", m.router.init, m.router},
{"acl manager", m.aclMgr.init, m.aclMgr},
}
if m.hasIPv6() {
steps = append(steps, initStep{"v6", m.family6})
steps = append(steps,
initStep{"v6 router", m.router6.init, m.router6},
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
)
}
var initialized []initStep
for _, s := range steps {
if err := s.r.init(stateManager); err != nil {
if err := s.init(stateManager); err != nil {
for i := len(initialized) - 1; i >= 0; i-- {
if rerr := initialized[i].r.Reset(); rerr != nil {
if rerr := initialized[i].mgr.Reset(); rerr != nil {
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
}
}
@@ -156,50 +176,84 @@ func (m *Manager) initChains(stateManager *statemanager.Manager) error {
return nil
}
// AddFilterRule installs a packet-filtering rule. See firewall.Manager
// docs for destination semantics. Sources are a single address family;
// the rule is dispatched to the matching v4 / v6 backend.
func (m *Manager) AddFilterRule(
// AddPeerFiltering adds a rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if len(sources) == 0 {
return nil, firewall.ErrNoSources
}
ipsetName string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
fam := m.family4
if isIPv6Rule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
}
fam = m.family6
if ip.To4() != nil {
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
// DeleteFilterRule removes a rule previously added via AddFilterRule.
// The rule is looked up by id in each family's filter cache.
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
id := rule.ID()
if m.family4.hasRule(id) {
return m.family4.DeleteFilterRule(rule)
if isIPv6RouteRule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
if m.hasIPv6() && m.family6.hasRule(id) {
return m.family6.DeleteFilterRule(rule)
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
log.Debugf("filter rule %s not found in any family", id)
return nil
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6IptRule(rule) {
return m.aclMgr6.DeletePeerRule(rule)
}
return m.aclMgr.DeletePeerRule(rule)
}
func isIPv6IptRule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.v6
}
// DeleteRouteRule deletes a routing rule.
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
return m.router6.DeleteRouteRule(rule)
}
return m.router.DeleteRouteRule(rule)
}
func (m *Manager) IsServerRouteSupported() bool {
@@ -218,10 +272,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddNatRule(pair)
return m.router6.AddNatRule(pair)
}
if err := m.family4.AddNatRule(pair); err != nil {
if err := m.router.AddNatRule(pair); err != nil {
return err
}
@@ -230,7 +284,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.family6.AddNatRule(v6Pair); err != nil {
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
@@ -246,18 +300,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return nil
}
return m.family6.RemoveNatRule(pair)
return m.router6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.family4.RemoveNatRule(pair); err != nil {
if err := m.router.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
@@ -266,11 +320,11 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.family6, isLegacy)
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
}
@@ -287,13 +341,19 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
}
if m.hasIPv6() {
if err := m.family6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
if err := m.aclMgr6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
}
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
}
}
if err := m.family4.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
}
if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
}
// Appending to merr intentionally blocks DeleteState below so ShutdownState
@@ -312,6 +372,27 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
return nberrors.FormatErrorOrNil(merr)
}
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error {
var merr *multierror.Error
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
}
if m.hasIPv6() {
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
}
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nberrors.FormatErrorOrNil(merr)
}
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
@@ -321,14 +402,14 @@ func (m *Manager) SetLogLevel(log.Level) {
}
func (m *Manager) EnableRouting() error {
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil
}
func (m *Manager) DisableRouting() error {
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil
@@ -343,9 +424,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddDNATRule(rule)
return m.router6.AddDNATRule(rule)
}
return m.family4.AddDNATRule(rule)
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
@@ -353,10 +434,10 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.family4.hasDNATRule(rule.ID()) {
return m.family6.DeleteDNATRule(rule)
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
return m.router6.DeleteDNATRule(rule)
}
return m.family4.DeleteDNATRule(rule)
return m.router.DeleteDNATRule(rule)
}
// UpdateSet updates the set with the given prefixes
@@ -373,12 +454,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
}
}
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
@@ -395,9 +476,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
@@ -409,9 +490,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
@@ -423,9 +504,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
@@ -437,14 +518,14 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOutput = "OUTPUT"
chainOUTPUT = "OUTPUT"
tableRaw = "raw"
)
@@ -519,15 +600,15 @@ func (m *Manager) initNoTrackChain() error {
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.InsertUnique(tableRaw, chainOutput, 1, jumpRule...); err != nil {
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add output jump rule: %w", err)
}
if err := m.ipv4Client.InsertUnique(tableRaw, chainPrerouting, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); delErr != nil {
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
log.Debugf("delete output jump rule: %v", delErr)
}
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
@@ -554,11 +635,11 @@ func (m *Manager) cleanupNoTrackChain() error {
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); err != nil {
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
return fmt.Errorf("remove output jump rule: %w", err)
}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPrerouting, jumpRule...); err != nil {
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
return fmt.Errorf("remove prerouting jump rule: %w", err)
}
@@ -573,13 +654,3 @@ func (m *Manager) cleanupNoTrackChain() error {
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}
// isIPv6Rule reports whether the rule belongs to the IPv6 family, from
// the destination prefix when set, otherwise from the (single-family)
// sources.
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}

View File

@@ -1,5 +1,3 @@
//go:build integration && !android
package iptables
import (
@@ -67,39 +65,46 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second)
}()
var rule2 fw.Rule
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{
IsRange: true,
Values: []uint16{8043, 8046},
}
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
rr := rule2.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
for _, r := range rule2 {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
})
t.Run("delete second rule", func(t *testing.T) {
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
})
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "udp", nil, port, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
err = manager.Close(nil)
require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainACLInput)
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
require.NoError(t, err, "failed check chain exists")
if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainACLInput)
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
}
})
}
@@ -121,13 +126,15 @@ func TestIptablesManagerDenyRules(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{22}}
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh")
require.NoError(t, err, "failed to add deny rule")
require.NotNil(t, rule, "deny rule should not be nil")
require.NotEmpty(t, rule, "deny rule should not be empty")
// Verify the rule was added by checking iptables
rr := rule.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
for _, r := range rule {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
})
t.Run("deny rule precedence test", func(t *testing.T) {
@@ -135,40 +142,36 @@ func TestIptablesManagerDenyRules(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
// Add accept rule first
_, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
_, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http")
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for same IP/port - this should take precedence
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http")
require.NoError(t, err, "failed to add deny rule")
// Inspect the actual iptables rules to verify deny rule comes before accept rule
rules, err := ipv4Client.List("filter", chainACLInput)
rules, err := ipv4Client.List("filter", chainNameInputRules)
require.NoError(t, err, "failed to list iptables rules")
// Debug: print all rules
t.Logf("All iptables rules in chain %s:", chainACLInput)
t.Logf("All iptables rules in chain %s:", chainNameInputRules)
for i, rule := range rules {
t.Logf(" [%d] %s", i, rule)
}
// Single-source rules emit a direct `-s <ip>/32 ... --dport 80`
// match. Match on that shape instead of the legacy
// per-(action,port) ipset names ("deny-http"/"accept-http")
// that this test predates.
srcMatch := fmt.Sprintf("-s %s/32", ip)
var denyRuleIndex, acceptRuleIndex = -1, -1
for i, rule := range rules {
if !strings.Contains(rule, srcMatch) || !strings.Contains(rule, "--dport 80") {
continue
}
if strings.Contains(rule, "-j DROP") {
if strings.Contains(rule, "DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)
denyRuleIndex = i
if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") {
denyRuleIndex = i
}
}
if strings.Contains(rule, "-j ACCEPT") {
if strings.Contains(rule, "ACCEPT") {
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
acceptRuleIndex = i
if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") {
acceptRuleIndex = i
}
}
}
@@ -193,6 +196,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
},
}
// just check on the local interface
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -206,39 +210,27 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second)
}()
var rule2 fw.Rule
t.Run("single source uses direct -s match (no ipset)", func(t *testing.T) {
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{
Values: []uint16{443},
}
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
require.NotNil(t, rule2)
require.Contains(t, rule2.(*Rule).specs, "-s",
"single-source rule should use direct -s match, not an ipset")
require.Empty(t, findSets(rule2.(*Rule).specs),
"single-source rule should not allocate a shared ipset")
})
t.Run("delete single-source rule", func(t *testing.T) {
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
})
t.Run("multi-source uses shared ipset", func(t *testing.T) {
sources := []netip.Prefix{
netip.PrefixFrom(netip.MustParseAddr("10.20.0.3"), 32),
netip.PrefixFrom(netip.MustParseAddr("10.20.0.4"), 32),
netip.PrefixFrom(netip.MustParseAddr("10.20.0.5"), 32),
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
}
port := &fw.Port{Values: []uint16{8080}}
multi, err := manager.AddFilterRule(nil, sources, fw.Network{}, "tcp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add multi-source rule")
require.NotNil(t, multi, "multi-source rule must produce one iptables rule")
sets := findSets(multi.(*Rule).specs)
require.Len(t, sets, 1, "multi-source rule must reference exactly one ipset")
})
require.NoError(t, manager.DeleteFilterRule(multi))
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
}
})
t.Run("reset check", func(t *testing.T) {
@@ -289,7 +281,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
//go:build integration && !android
//go:build !android
package iptables
@@ -31,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
@@ -52,12 +52,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 11. MSS clamping rule for outbound traffic
require.Len(t, manager.rules, 11, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPostrouting, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPostrouting)
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting jump rule should exist")
exists, err = manager.iptablesClient.Exists(tableMangle, chainPrerouting, "-j", chainRTPre)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPrerouting)
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
pair := firewall.RouterPair{
@@ -84,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
@@ -95,7 +95,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "marking rule should be inserted")
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -106,8 +106,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
if testCase.InputPair.Masquerade {
require.True(t, exists, "marking rule should be created")
foundRule, found := manager.rules[natRuleKey]
@@ -121,7 +121,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -132,8 +132,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
if testCase.InputPair.Masquerade {
require.True(t, exists, "inverse marking rule should be created")
foundRule, found := manager.rules[inverseRuleKey]
@@ -157,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
@@ -170,7 +170,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -181,8 +181,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "marking rule should not exist")
_, found := manager.rules[natRuleKey]
@@ -190,7 +190,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
// Check inverse rule removal
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -201,8 +201,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "inverse marking rule should not exist")
_, found = manager.rules[inverseRuleKey]
@@ -219,13 +219,13 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family manager")
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))
defer func() {
err := r.Reset()
require.NoError(t, err, "Failed to reset family")
require.NoError(t, err, "Failed to reset router")
}()
tests := []struct {
@@ -334,30 +334,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddFilterRule failed")
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
stored, ok := r.filters[ruleKey.ID()]
require.True(t, ok, "rule not stored in filters")
t.Logf("Internal rule: %v", stored.specs)
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
exists, err := iptablesClient.Exists(tableFilter, chainRTFwdIn, stored.specs...)
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")
if tt.expectSet {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
_, exists := r.ipsetCounter.Get(setName)
assert.True(t, exists, "IPSet not created")
assert.NotEmpty(t, findSets(stored.specs), "Rule should reference an ipset")
var source firewall.Network
if len(tt.sources) > 1 {
source.Set = firewall.NewPrefixSet(tt.sources)
} else if len(tt.sources) > 0 {
source.Prefix = tt.sources[0]
}
// Verify rule content
params := routeFilteringRuleParams{
Source: source,
Destination: firewall.Network{Prefix: tt.destination},
Proto: tt.proto,
SPort: tt.sPort,
DPort: tt.dPort,
Action: tt.action,
}
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
expectedRule, err := r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec")
if tt.expectSet {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
expectedRule, err = r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec with set")
// Check if the set was created
_, exists := r.ipsetCounter.Get(setName)
assert.True(t, exists, "IPSet not created")
}
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
})
}
}
func TestFindSetNameInRule(t *testing.T) {
r := &router{}
testCases := []struct {
name string
rule []string
@@ -398,7 +430,7 @@ func TestFindSetNameInRule(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := findSets(tc.rule)
result := r.findSets(tc.rule)
if len(result) != len(tc.expected) {
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)

View File

@@ -1,265 +0,0 @@
//go:build !android
package iptables
import (
"fmt"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) AddNatRule(pair firewall.RouterPair) error {
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if pair.Masquerade {
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
r.updateState()
return nil
}
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
r.updateState()
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.ForwardingFormat)
if err := r.removeLegacyRouteRule(pair); err != nil {
return err
}
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", "ACCEPT"}
if err := r.iptablesClient.Append(tableFilter, chainRTFwdIn, rule...); err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
r.rules[ruleID] = rule
return nil
}
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.ForwardingFormat)
if rule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleID)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
}
return nil
}
// GetLegacyManagement returns the current legacy management mode
func (r *family) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *family) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *family) RemoveAllLegacyRouteRules() error {
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
continue
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %w", err))
} else {
delete(r.rules, k)
}
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", "MASQUERADE",
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %w", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", "MASQUERADE",
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %w", err)
}
r.rules["static-nat-return"] = rule2
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
func (r *family) addMSSClampingRules() error {
overhead := uint16(ipv4TCPHeaderSize)
if r.v6 {
overhead = ipv6TCPHeaderSize
}
mss := r.mtu - overhead
// Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{
"-j", chainRTMSSClamp,
}
if err := r.iptablesClient.Insert(tableMangle, chainForward, 1, jumpRule...); err != nil {
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
}
r.rules[jumpMSSClamp] = jumpRule
ruleOut := []string{
"-o", r.wgIface.Name(),
"-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mss),
}
if err := r.iptablesClient.Append(tableMangle, chainRTMSSClamp, ruleOut...); err != nil {
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
}
r.rules["mss-clamp-out"] = ruleOut
return nil
}
func (r *family) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
if err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
ruleID := firewall.RuleID("established-" + chain)
r.rules[ruleID] = establishedRule
return nil
}
func (r *family) addNatRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.NatFormat)
if rule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
return fmt.Errorf("remove existing marking rule for %s: %w", pair.Destination, err)
}
delete(r.rules, ruleID)
}
markValue := nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
)
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
if err != nil {
return fmt.Errorf("apply network -s: %w", err)
}
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
if err != nil {
return fmt.Errorf("apply network -d: %w", err)
}
rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
rule = append(rule,
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
// Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
if err := r.iptablesClient.Insert(tableMangle, chainRTPre, 1, rule...); err != nil {
r.dropSourceMatch(rule)
return fmt.Errorf("add marking rule for %s: %w", pair.Destination, err)
}
r.rules[ruleID] = rule
return nil
}
func (r *family) removeNatRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.NatFormat)
if rule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
return fmt.Errorf("remove marking rule for %s: %w", pair.Destination, err)
}
delete(r.rules, ruleID)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} else {
log.Debugf("marking rule %s not found", ruleID)
}
return nil
}

View File

@@ -1,20 +1,18 @@
package iptables
import "github.com/netbirdio/netbird/client/firewall/manager"
// Rule to handle management of rules. Source set membership (when the
// rule was built against a shared hash:net ipset) is encoded in specs;
// DeleteFilterRule recovers it via findSets so the refcounter can drop
// the right reference.
// Rule to handle management of rules
type Rule struct {
id manager.RuleID
ruleID string
ipsetName string
specs []string
mangleSpecs []string
ip string
chain string
v6 bool
}
// ID returns the rule id
func (r *Rule) ID() manager.RuleID {
return r.id
// GetRuleID returns the rule id
func (r *Rule) ID() string {
return r.ruleID
}

View File

@@ -0,0 +1,103 @@
package iptables
import "encoding/json"
type ipList struct {
ips map[string]struct{}
}
func newIpList(ip string) *ipList {
ips := make(map[string]struct{})
ips[ip] = struct{}{}
return &ipList{
ips: ips,
}
}
func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPs map[string]struct{} `json:"ips"`
}{
IPs: s.ips,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipList) UnmarshalJSON(data []byte) error {
temp := struct {
IPs map[string]struct{} `json:"ips"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ips = temp.IPs
if temp.IPs == nil {
temp.IPs = make(map[string]struct{})
}
return nil
}
type ipsetStore struct {
ipsets map[string]*ipList
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsets: make(map[string]*ipList),
}
}
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
s.ipsets[ipsetName] = list
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) ipsetNames() []string {
names := make([]string, 0, len(s.ipsets))
for name := range s.ipsets {
names = append(names, name)
}
return names
}
// MarshalJSON implements json.Marshaler
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPSets map[string]*ipList `json:"ipsets"`
}{
IPSets: s.ipsets,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
temp := struct {
IPSets map[string]*ipList `json:"ipsets"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ipsets = temp.IPSets
if temp.IPSets == nil {
temp.IPSets = make(map[string]*ipList)
}
return nil
}

View File

@@ -29,13 +29,17 @@ type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
// IPv6 counterparts
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
}
func (s *ShutdownState) Name() string {
@@ -53,14 +57,17 @@ func (s *ShutdownState) Cleanup() error {
}
if s.RouteRules != nil {
ipt.family4.rules = s.RouteRules
ipt.router.rules = s.RouteRules
}
if s.RouteIPsetCounter != nil {
ipt.family4.ipsetCounter.LoadData(s.RouteIPsetCounter)
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
}
if s.ACLEntries != nil {
ipt.family4.entries = s.ACLEntries
ipt.aclMgr.entries = s.ACLEntries
}
if s.ACLIPsetStore != nil {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
// Clean up v6 state even if the current run has no IPv6.
@@ -72,13 +79,16 @@ func (s *ShutdownState) Cleanup() error {
}
if ipt.hasIPv6() {
if s.RouteRules6 != nil {
ipt.family6.rules = s.RouteRules6
ipt.router6.rules = s.RouteRules6
}
if s.RouteIPsetCounter6 != nil {
ipt.family6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
}
if s.ACLEntries6 != nil {
ipt.family6.entries = s.ACLEntries6
ipt.aclMgr6.entries = s.ACLEntries6
}
if s.ACLIPsetStore6 != nil {
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
}
}

View File

@@ -1,27 +0,0 @@
//go:build integration && !android
package iptables
import (
"fmt"
"net"
"net/netip"
)
func pfx(ip net.IP) []netip.Prefix {
if ip == nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if ip.IsUnspecified() {
if ip.To4() != nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
a, ok := netip.AddrFromSlice(ip)
if !ok {
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
}
a = a.Unmap()
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
}

View File

@@ -3,6 +3,7 @@ package manager
import (
"errors"
"fmt"
"net"
"net/netip"
"sort"
@@ -15,12 +16,6 @@ import (
// method but the IPv6 firewall components were not initialized.
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
// ErrNoSources is returned when AddFilterRule is called with an empty
// source list. "Match any source" must be expressed explicitly with a
// /0 prefix; an empty list is a caller error and is rejected rather
// than silently widening the rule to every source.
var ErrNoSources = errors.New("rule has no sources")
const (
ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t"
@@ -28,18 +23,13 @@ const (
NatFormat = "netbird-nat-%s-%t"
)
// RuleID identifies a firewall rule. It is a typed string so the
// compiler catches accidental mixing with arbitrary string keys. It is
// only an identifier and does not implement Rule.
type RuleID string
// Rule abstraction should be implemented by each firewall manager
//
// Each firewall type for different OS can use different type
// of the properties to hold data of the created rule
type Rule interface {
// ID returns the rule id
ID() RuleID
ID() string
}
// RuleDirection is the traffic direction which a rule is applied
@@ -101,13 +91,6 @@ func (d Network) IsPrefix() bool {
return d.Prefix.IsValid()
}
// IsZero returns true if the network designates no destination, i.e. it
// is the zero value. A zero Network is the peer-rule sentinel; a non-zero
// one carries a prefix or set destination.
func (d Network) IsZero() bool {
return !d.IsPrefix() && !d.IsSet()
}
// Manager is the high level abstraction of a firewall manager
//
// It declares methods which handle actions required by the
@@ -115,42 +98,46 @@ func (d Network) IsZero() bool {
type Manager interface {
Init(stateManager *statemanager.Manager) error
// AddFilterRule adds a packet-filtering rule to the firewall.
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddPeerFiltering adds a rule to the firewall
//
// If destination is the zero Network, the rule applies to traffic
// inbound to this node, i.e. peer ACL semantics, installed in
// the kernel's input chain. If destination is set (prefix or
// set), the rule applies to forwarded traffic with that
// destination, route ACL semantics, installed in the forward
// chain.
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
//
// sources must be a single address family; the caller splits mixed
// families and calls once per family. "Match any source" must be
// expressed with an explicit /0 prefix; an empty sources list is
// rejected with ErrNoSources so a zeroed list can never widen a
// rule to every source.
//
// Note: callers should call Flush() after adding rules.
AddFilterRule(
// Note: Callers should call Flush() after adding rules to ensure
// they are applied to the kernel and rule handles are refreshed.
AddPeerFiltering(
id []byte,
sources []netip.Prefix,
destination Network,
ip net.IP,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
) (Rule, error)
ipsetName string,
) ([]Rule, error)
// DeleteFilterRule removes a filtering rule previously added via
// AddFilterRule. The rule's own type identifies whether it lives
// in the peer (input) or route (forward) path.
DeleteFilterRule(rule Rule) error
// DeletePeerRule from the firewall by rule definition
DeletePeerRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
IsStateful() bool
AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination Network,
proto Protocol,
sPort, dPort *Port,
action Action,
) (Rule, error)
// DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair RouterPair) error
@@ -198,9 +185,8 @@ type Manager interface {
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
}
// GenKey builds the rule id for this pair from the given format.
func (p RouterPair) GenKey(format string) RuleID {
return RuleID(fmt.Sprintf(format, p.ID, p.Inverse))
func GenKey(format string, pair RouterPair) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse)
}
// LegacyManager defines the interface for legacy management operations
@@ -256,20 +242,6 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
return merged
}
// UnmapPrefix normalizes a v4-mapped v6 prefix (::ffff:a.b.c.d) to its
// plain v4 form, shifting the prefix length out of the 96-bit mapped
// range. Other prefixes are returned unchanged. Keeping prefixes
// unmapped ensures v4 rules match consistently and the match builders
// read the correct address length.
func UnmapPrefix(p netip.Prefix) netip.Prefix {
addr := p.Addr()
if !addr.Is4In6() {
return p
}
bits := max(p.Bits()-96, 0)
return netip.PrefixFrom(addr.Unmap(), bits)
}
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {

View File

@@ -13,13 +13,13 @@ type ForwardRule struct {
TranslatedPort Port
}
func (r ForwardRule) ID() RuleID {
func (r ForwardRule) ID() string {
id := fmt.Sprintf("%s;%s;%s;%s",
r.Protocol,
r.DestinationPort.String(),
r.TranslatedAddress.String(),
r.TranslatedPort.String())
return RuleID(id)
return id
}
func (r ForwardRule) String() string {

View File

@@ -40,7 +40,7 @@ func (h Set) Comment() string {
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
func NewPrefixSet(prefixes []netip.Prefix) Set {
prefixes = slices.Clone(prefixes)
// sort for consistent naming
SortPrefixes(prefixes)
hash := sha256.New()

View File

@@ -0,0 +1,713 @@
package nftables
import (
"bytes"
"fmt"
"net"
"slices"
"strconv"
"strings"
"time"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
// rules chains contains the effective ACL rules
chainNameInputRules = "netbird-acl-input-rules"
// filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
)
const flushError = "flush: %w"
type AclManager struct {
rConn *nftables.Conn
sConn *nftables.Conn
wgIface iFaceMapper
routingFwChainName string
af addrFamily
workTable *nftables.Table
chainInputRules *nftables.Chain
chainPrerouting *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for both type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, fmt.Errorf("create nf conn: %w", err)
}
return &AclManager{
rConn: &nftables.Conn{},
sConn: sConn,
wgIface: wgIface,
workTable: table,
routingFwChainName: routingFwChainName,
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule),
}, nil
}
func (m *AclManager) init(workTable *nftables.Table) error {
m.workTable = workTable
return m.createDefaultChains()
}
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
var ipset *nftables.Set
if ipsetName != "" {
var err error
ipset, err = m.addIpToSet(ipsetName, ip)
if err != nil {
return nil, err
}
}
newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
if err != nil {
return nil, err
}
newRules = append(newRules, ioRule)
return newRules, nil
}
// DeletePeerRule from the firewall by rule definition
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if r.nftSet == nil {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
return m.rConn.Flush()
}
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
if !ok {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
return m.rConn.Flush()
}
if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
if err != nil {
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
}
if err := m.sConn.Flush(); err != nil {
log.Debugf("flush error of set delete element, %s", r.nftSet.Name)
return err
}
m.ipsetStore.DeleteIpFromSet(r.nftSet.Name, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(ips) > 0 {
return nil
}
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if err := m.rConn.Flush(); err != nil {
return err
}
delete(m.rules, r.ID())
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.rConn.FlushSet(r.nftSet)
m.rConn.DelSet(r.nftSet)
m.ipsetStore.deleteIpset(r.nftSet.Name)
return nil
}
// createDefaultAllowRules creates default allow rules for the input and output chains
func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainInputRules,
Position: 0,
Exprs: expIn,
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (m *AclManager) Flush() error {
if err := m.flushWithBackoff(); err != nil {
return err
}
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
}
return nil
}
func (m *AclManager) addIOFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipset *nftables.Set,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
nftRule: r.nftRule,
mangleRule: r.mangleRule,
nftSet: r.nftSet,
ruleID: r.ruleID,
ip: ip,
}, nil
}
var expressions []expr.Any
if proto != firewall.ProtocolALL {
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: m.af.protoOffset,
Len: uint32(1),
})
protoData, err := m.af.protoNum(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %v", err)
}
expressions = append(expressions, &expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: []byte{protoData},
})
}
rawIP := ipToBytes(ip, m.af)
// check if rawIP contains zeroed IPv4 0.0.0.0 value
// in that case not add IP match expression into the rule definition
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: m.af.srcAddrOffset,
Len: m.af.addrLen,
},
)
// add individual IP for match if no ipset defined
if ipset == nil {
expressions = append(expressions,
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
},
)
} else {
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipset.Name,
SetID: ipset.ID,
},
)
}
}
expressions = append(expressions, applyPort(sPort, true)...)
expressions = append(expressions, applyPort(dPort, false)...)
mainExpressions := slices.Clone(expressions)
switch action {
case firewall.ActionAccept:
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
case firewall.ActionDrop:
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(ruleId)
chain := m.chainInputRules
rule := &nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: mainExpressions,
UserData: userData,
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end
var nftRule *nftables.Rule
if action == firewall.ActionDrop {
nftRule = m.rConn.InsertRule(rule)
} else {
nftRule = m.rConn.AddRule(rule)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
}
ruleStruct := &Rule{
nftRule: nftRule,
// best effort mangle rule
mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset,
ruleID: ruleId,
ip: ip,
}
m.rules[ruleId] = ruleStruct
if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name)
}
return ruleStruct, nil
}
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if m.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
nfRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}
return nfRule
}
func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules
chain := m.createChain(chainNameInputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return fmt.Errorf(flushError, err)
}
m.chainInputRules = chain
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return err
}
// netbird-acl-forward-filter
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return fmt.Errorf(flushError, err)
}
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
log.Errorf("failed to allow redirected traffic: %s", err)
}
return nil
}
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
// Chain is created by route manager
// TODO: move creation to a common place
m.chainPrerouting = &nftables.Chain{
Name: chainNameManglePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
})
}
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.routingFwChainName,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) createChain(name string) *nftables.Chain {
chain := &nftables.Chain{
Name: name,
Table: m.workTable,
}
chain = m.rConn.AddChain(chain)
insertReturnTrafficRule(m.rConn, m.workTable, chain)
return chain
}
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Table: m.workTable,
Hooknum: hookNum,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
return m.rConn.AddChain(chain)
}
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: expressions,
})
return nil
}
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: to,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
rawIP := ipToBytes(ip, m.af)
if err != nil {
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
m.ipsetStore.newIpset(ipset.Name)
}
if m.ipsetStore.IsIpInSet(ipset.Name, ip) {
return ipset, nil
}
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
return nil, fmt.Errorf("add set element for the first time: %v", err)
}
m.ipsetStore.AddIpToSet(ipset.Name, ip)
if err := m.sConn.Flush(); err != nil {
return nil, fmt.Errorf("flush add elements: %v", err)
}
return ipset, nil
}
// createSet in given table by name
func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Set, error) {
ipset := &nftables.Set{
Name: name,
Table: table,
Dynamic: true,
KeyType: m.af.setKeyType,
}
if err := m.rConn.AddSet(ipset, nil); err != nil {
return nil, fmt.Errorf("create set: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush created set: %v", err)
}
return ipset, nil
}
func (m *AclManager) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime *= 2
continue
}
break
}
return
}
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
if m.workTable == nil || chain == nil {
return nil
}
list, err := m.rConn.GetRules(m.workTable, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) == 0 {
continue
}
split := bytes.Split(rule.UserData, []byte(" "))
r, ok := m.rules[string(split[0])]
if ok {
if mangle {
*r.mangleRule = *rule
} else {
*r.nftRule = *rule
}
}
}
return nil
}
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":" + string(proto) + ":"
if sPort != nil {
rulesetID += sPort.String()
}
rulesetID += ":"
if dPort != nil {
rulesetID += dPort.String()
}
rulesetID += ":"
rulesetID += strconv.Itoa(int(action))
if ipset == nil {
return "ip:" + ip.String() + rulesetID
}
return "set:" + ipset.Name + rulesetID
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, n+"\x00")
return b
}
// ipToBytes converts net.IP to the correct byte length for the address family.
func ipToBytes(ip net.IP, af addrFamily) []byte {
if af.addrLen == 4 {
return ip.To4()
}
return ip.To16()
}

View File

@@ -1,882 +0,0 @@
//go:build !android
package nftables
import (
"bytes"
"errors"
"fmt"
"slices"
"strings"
"time"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
})
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: &prio,
Type: nftables.ChainTypeNAT,
})
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingRdr,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePostrouting,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePrerouting,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
Name: chainNameMangleForward,
Table: r.workTable,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
r.addPostroutingRules()
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
// Kernel routing opens both INPUT and FORWARD.
if err := r.openInterface(true); err != nil {
log.Errorf("failed to open interface in foreign chains: %s", err)
}
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to refresh rules: %s", err)
}
return nil
}
// setupDataPlaneMark configures the fwmark for the data plane
func (r *family) setupDataPlaneMark() error {
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
return errors.New("no mangle chains found")
}
ctNew := getCtNewExprs()
preExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
preExprs = append(preExprs, ctNew...)
preExprs = append(preExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
preNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: preExprs,
}
r.conn.AddRule(preNftRule)
postExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
postExprs = append(postExprs, ctNew...)
postExprs = append(postExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
postNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePostrouting],
Exprs: postExprs,
}
r.conn.AddRule(postNftRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
return nil
}
// openInterface adds passthrough accept rules for the NetBird interface to the
// kernel's filter table and external chains so they don't drop our traffic.
// includeForward also opens the FORWARD chains (kernel routing); when false only
// INPUT is opened, which is all the userspace router needs since it never
// forwards in the kernel.
func (r *family) openInterface(includeForward bool) error {
var merr *multierror.Error
if err := r.acceptFilterTableRules(includeForward); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.acceptExternalChainsRules(includeForward); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) acceptFilterTableRules(includeForward bool) error {
if r.filterTable == nil {
return nil
}
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept input/forward rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available.
// Use the correct protocol (iptables vs ip6tables) for the address family.
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable, includeForward)
}
if err := r.acceptFilterRulesIptables(ipt, includeForward); err != nil {
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable, includeForward)
}
return nil
}
func (r *family) acceptFilterRulesIptables(ipt *iptables.IPTables, includeForward bool) error {
var merr *multierror.Error
if includeForward {
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
} else {
log.Debugf("added iptables forward rule: %v", rule)
}
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
} else {
log.Debugf("added iptables input rule: %v", inputRule)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) getAcceptForwardRules() [][]string {
intf := r.wgIface.Name()
return [][]string{
{"-i", intf, "-j", "ACCEPT"},
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
}
}
func (r *family) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
}
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
// This is used when iptables is not available.
func (r *family) acceptFilterRulesNftables(table *nftables.Table, includeForward bool) error {
intf := ifname(r.wgIface.Name())
if includeForward {
forwardChain := &nftables.Chain{
Name: chainNameForward,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
}
r.insertForwardAcceptRules(forwardChain, intf)
}
inputChain := &nftables.Chain{
Name: chainNameInput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
}
r.insertInputAcceptRule(inputChain, intf)
return r.conn.Flush()
}
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
// It dynamically finds chains at call time to handle chains that may have been created after startup.
func (r *family) acceptExternalChainsRules(includeForward bool) error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
intf := ifname(r.wgIface.Name())
for _, chain := range chains {
r.applyExternalChainAccept(chain, intf, includeForward)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush external chain rules: %w", err)
}
return nil
}
func (r *family) applyExternalChainAccept(chain *nftables.Chain, intf []byte, includeForward bool) {
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
return
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
if includeForward {
r.insertForwardAcceptRules(chain, intf)
}
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
func (r *family) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
existing, err := r.existingNetbirdRulesInChain(chain)
if err != nil {
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
return
}
r.insertForwardIifRule(chain, intf, existing)
r.insertForwardOifEstablishedRule(chain, intf, existing)
}
func (r *family) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
if existing[userDataAcceptForwardRuleIif] {
return
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleIif),
})
}
func (r *family) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
if existing[userDataAcceptForwardRuleOif] {
return
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: append(exprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
})
}
func (r *family) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
existing, err := r.existingNetbirdRulesInChain(chain)
if err != nil {
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
return
}
if existing[userDataAcceptInputRule] {
return
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptInputRule),
})
}
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
func (r *family) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return nil, fmt.Errorf("list rules: %w", err)
}
present := map[string]bool{}
for _, rule := range rules {
if !isNetbirdAcceptRuleTag(rule.UserData) {
continue
}
present[string(rule.UserData)] = true
}
return present, nil
}
func isNetbirdAcceptRuleTag(userData []byte) bool {
switch string(userData) {
case userDataAcceptForwardRuleIif,
userDataAcceptForwardRuleOif,
userDataAcceptInputRule:
return true
}
return false
}
func (r *family) removeAcceptFilterRules() error {
var merr *multierror.Error
if err := r.removeFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) removeFilterTableRules() error {
if r.filterTable == nil {
return nil
}
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil {
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return nil
}
func (r *family) removeAcceptRulesFromTable(table *nftables.Table) error {
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != table.Name {
continue
}
if chain.Name != chainNameForward && chain.Name != chainNameInput {
continue
}
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
return err
}
}
return r.conn.Flush()
}
func (r *family) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
}
}
}
return nil
}
// removeExternalChainsRules removes our accept rules from all external chains.
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
// ensuring cleanup works even after a crash or if chains changed.
func (r *family) removeExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
var merr *multierror.Error
for _, chain := range chains {
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove rules from external chain %s/%s: %w", chain.Table.Name, chain.Name, err))
continue
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush external chain %s/%s: %w", chain.Table.Name, chain.Name, err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
func (r *family) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
if err != nil {
log.Debugf("list chains for family %d: %v", family, err)
continue
}
for _, chain := range allChains {
if r.isExternalChain(chain) {
chains = append(chains, chain)
}
}
}
return chains
}
func (r *family) isExternalChain(chain *nftables.Chain) bool {
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
return false
}
// Skip firewalld-owned chains. Firewalld creates its chains with the
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
// We delegate acceptance to firewalld by trusting the interface instead.
if chain.Table.Name == firewalldTableName {
return false
}
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
return false
}
if chain.Type != nftables.ChainTypeFilter {
return false
}
if chain.Hooknum == nil {
return false
}
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
}
func isIptablesTable(name string) bool {
switch name {
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
return true
}
return false
}
func (r *family) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
}
return nberrors.FormatErrorOrNil(merr)
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (r *family) Flush() error {
if err := r.flushWithBackoff(); err != nil {
return err
}
if err := r.refreshRuleHandles(r.chainInputRules, false); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
}
return nil
}
func (r *family) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if r.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
nfRule := r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
if err := r.conn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}
return nfRule
}
func (r *family) createDefaultChains() (err error) {
// chainNameInputRules
chain := r.createChain(chainNameInputRules)
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return fmt.Errorf(flushError, err)
}
r.chainInputRules = chain
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = r.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
r.addJumpRule(chain, r.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
r.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return err
}
// netbird-acl-forward-filter
chainFwFilter := r.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
r.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
r.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return fmt.Errorf(flushError, err)
}
if err := r.allowRedirectedTraffic(chainFwFilter); err != nil {
log.Errorf("failed to allow redirected traffic: %s", err)
}
return nil
}
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (r *family) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
r.chainPrerouting = r.chains[chainNameManglePrerouting]
r.addFwmarkToForward(chainFwFilter)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *family) addFwmarkToForward(chainFwFilter *nftables.Chain) {
r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
})
}
func (r *family) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: r.routingFwChainName,
},
}
_ = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: chainFwFilter,
Exprs: expressions,
})
}
func (r *family) createChain(name string) *nftables.Chain {
chain := &nftables.Chain{
Name: name,
Table: r.workTable,
}
chain = r.conn.AddChain(chain)
insertReturnTrafficRule(r.conn, r.workTable, chain)
return chain
}
func (r *family) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Table: r.workTable,
Hooknum: hookNum,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
return r.conn.AddChain(chain)
}
func (r *family) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
_ = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: expressions,
})
return nil
}
func (r *family) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: to,
},
}
_ = r.conn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (r *family) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime *= 2
continue
}
break
}
return
}
func (r *family) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
if r.workTable == nil || chain == nil {
return nil
}
list, err := r.conn.GetRules(r.workTable, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) == 0 {
continue
}
pr, ok := r.filters[firewall.RuleID(rule.UserData)]
if !ok {
continue
}
if mangle {
if pr.mangleRule != nil {
*pr.mangleRule = *rule
}
} else {
*pr.nftRule = *rule
}
}
return nil
}

View File

@@ -1,550 +0,0 @@
//go:build !android
package nftables
import (
"fmt"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
ruleID := rule.ID()
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
return rule, nil
}
protoNum, err := r.af.protoNum(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
// Request forwarding once the rule is about to be installed, releasing
// it if a later step fails so the refcount tracks the real rules.
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
if err := r.addDnatRedirect(rule, protoNum, ruleID); err != nil {
r.releaseForwarding()
return nil, err
}
r.addDnatMasq(rule, protoNum, ruleID)
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
// TODO: find chains with drop policies and add rules there
if err := r.conn.Flush(); err != nil {
r.releaseForwarding()
return nil, fmt.Errorf("flush rules: %w", err)
}
return &rule, nil
}
func (r *family) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) error {
dnatExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
}
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
// shifted translated port is not supported in nftables, so we hand this over to xtables
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
return r.addXTablesRedirect(dnatExprs, ruleID, rule)
}
}
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
if err != nil {
return err
}
dnatExprs = append(dnatExprs, additionalExprs...)
dnatExprs = append(dnatExprs,
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: dnatExprs,
UserData: []byte(ruleID + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleID+dnatSuffix] = dnatRule
return nil
}
func (r *family) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
switch {
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
return r.handlePortRange(rule)
case len(rule.TranslatedPort.Values) == 0:
return r.handleAddressOnly(rule)
case len(rule.TranslatedPort.Values) == 1:
return r.handleSinglePort(rule)
default:
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
}
func (r *family) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
&expr.Immediate{
Register: 3,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
},
}
return exprs, 2, 3, nil
}
func (r *family) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
return exprs, 0, 0, nil
}
func (r *family) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
}
return exprs, 2, 0, nil
}
func (r *family) addXTablesRedirect(dnatExprs []expr.Any, ruleID firewall.RuleID, rule firewall.ForwardRule) error {
dnatExprs = append(dnatExprs,
&expr.Counter{},
&expr.Target{
Name: "DNAT",
Rev: 2,
Info: &xt.NatRange2{
NatRange: xt.NatRange{
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
MinIP: rule.TranslatedAddress.AsSlice(),
MaxIP: rule.TranslatedAddress.AsSlice(),
MinPort: rule.TranslatedPort.Values[0],
MaxPort: rule.TranslatedPort.Values[1],
},
BasePort: rule.DestinationPort.Values[0],
},
},
)
natTable := &nftables.Table{
Name: tableNat,
Family: r.af.tableFamily,
}
dnatRule := &nftables.Rule{
Table: natTable,
Chain: &nftables.Chain{
Name: chainNameNatPrerouting,
Table: natTable,
Type: nftables.ChainTypeNAT,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
},
Exprs: dnatExprs,
UserData: []byte(ruleID + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleID+dnatSuffix] = dnatRule
return nil
}
func (r *family) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) {
masqExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: r.af.dstAddrOffset,
Len: r.af.addrLen,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
masqExprs = append(masqExprs, &expr.Masq{})
masqRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: masqExprs,
UserData: []byte(ruleID + snatSuffix),
}
r.conn.AddRule(masqRule)
r.rules[ruleID+snatSuffix] = masqRule
}
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
ruleID := rule.ID()
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
var needsFlush bool
var found bool
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
found = true
if dnatRule.Handle == 0 {
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleID+dnatSuffix)
delete(r.rules, ruleID+dnatSuffix)
} else if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
} else {
needsFlush = true
}
}
if masqRule, exists := r.rules[ruleID+snatSuffix]; exists {
found = true
if masqRule.Handle == 0 {
log.Warnf("snat rule %s has no handle, removing stale entry", ruleID+snatSuffix)
delete(r.rules, ruleID+snatSuffix)
} else if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
} else {
needsFlush = true
}
}
if needsFlush {
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
}
if merr != nil {
return nberrors.FormatErrorOrNil(merr)
}
delete(r.rules, ruleID+dnatSuffix)
delete(r.rules, ruleID+snatSuffix)
// Release once, only if the rule was present and removed.
if found {
r.releaseForwarding()
}
return nil
}
// releaseForwarding drops one IP forwarding reference, logging any error.
func (r *family) releaseForwarding() {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("release IP forwarding: %v", err)
}
}
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := r.af.protoNum(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 3,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(originalPort),
},
}
bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(translatedPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
func (r *family) ensureNATOutputChain() error {
if _, exists := r.chains[chainNameNATOutput]; exists {
return nil
}
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
Name: chainNameNATOutput,
Table: r.workTable,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
if err := r.conn.Flush(); err != nil {
delete(r.chains, chainNameNATOutput)
return fmt.Errorf("create NAT output chain: %w", err)
}
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
protoNum, err := r.af.protoNum(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: binaryutil.BigEndian.PutUint16(originalPort),
},
}
bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(translatedPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: 2,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameNATOutput],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}

View File

@@ -1,249 +0,0 @@
//go:build !android
package nftables
import (
"fmt"
"net/netip"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
const (
tableNat = "nat"
tableMangle = "mangle"
tableRaw = "raw"
tableSecurity = "security"
chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameNATOutput = "netbird-nat-output"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
// Peer ACL chain names.
chainNameInputRules = "netbird-acl-input-rules"
chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
flushError = "flush: %w"
firewalldTableName = "firewalld"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept"
dnatSuffix firewall.RuleID = "_dnat"
snatSuffix firewall.RuleID = "_snat"
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
maxPrefixesSet = 1500
refreshRulesMapError = "refresh rules map: %w"
)
var (
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
)
type setInput struct {
set firewall.Set
prefixes []netip.Prefix
}
// family holds the per-address-family nftables state. One instance
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
// single family; the top-level Manager owns one for v4 and another
// for v6. The name predates the peer-ACL absorption; it's effectively
// the per-family backend now.
type family struct {
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// filters holds peer + route filter rules keyed by content hash.
// AddFilterRule writes here; DeleteFilterRule looks up by id.
filters map[firewall.RuleID]*Rule
// rules holds NAT, DNAT, and external accept rules (auxiliary
// plumbing that isn't a filter rule).
rules map[firewall.RuleID]*nftables.Rule
// Peer ACL chain handles.
chainInputRules *nftables.Chain
chainPrerouting *nftables.Chain
routingFwChainName string
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
af addrFamily
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
mtu uint16
}
func newFamily(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*family, error) {
r := &family{
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
filters: make(map[firewall.RuleID]*Rule),
rules: make(map[firewall.RuleID]*nftables.Rule),
routingFwChainName: chainNameRoutingFw,
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu,
}
r.ipsetCounter = refcounter.New(
r.createIpSet,
r.deleteIpSet,
)
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
log.Debugf("ip filter table not found: %v", err)
}
return r, nil
}
func (r *family) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptFilterRules(); err != nil {
log.Errorf("failed to clean up rules from filter table: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
if err := r.createDefaultChains(); err != nil {
return fmt.Errorf("create default acl chains: %w", err)
}
return nil
}
// Reset cleans existing nftables filter table rules from the system
func (r *family) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
var merr *multierror.Error
if err := r.removeAcceptFilterRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
}
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
if err != nil {
return nil, fmt.Errorf("list tables: %w", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func hookName(hook *nftables.ChainHook) string {
if hook == nil {
return "unknown"
}
switch *hook {
case *nftables.ChainHookForward:
return chainNameForward
case *nftables.ChainHookInput:
return chainNameInput
default:
return fmt.Sprintf("hook(%d)", *hook)
}
}
func familyName(family nftables.TableFamily) string {
switch family {
case nftables.TableFamilyIPv4:
return "ip"
case nftables.TableFamilyIPv6:
return "ip6"
case nftables.TableFamilyINet:
return "inet"
default:
return fmt.Sprintf("family(%d)", family)
}
}
func (r *family) iptablesProto() iptables.Protocol {
if r.af.tableFamily == nftables.TableFamilyIPv6 {
return iptables.ProtocolIPv6
}
return iptables.ProtocolIPv4
}
func (r *family) refreshRulesMap() error {
var merr *multierror.Error
newRules := make(map[firewall.RuleID]*nftables.Rule)
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
// preserve existing entries for this chain since we can't verify their state
for k, v := range r.rules {
if v.Chain != nil && v.Chain.Name == chain.Name {
newRules[k] = v
}
}
continue
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
newRules[firewall.RuleID(rule.UserData)] = rule
}
}
}
r.rules = newRules
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -1,441 +0,0 @@
//go:build !android
package nftables
import (
"fmt"
"net"
"net/netip"
"slices"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
)
// AddFilterRule installs one nftables packet-filter rule. With
// destination empty the rule goes to the peer ACL input chain plus a
// paired prerouting mangle rule for the redirect mark. With
// destination set (prefix or named set) it goes to the route ACL
// forward chain. Multi-source rules collapse to one nftables rule
// backed by the shared refcounted hash:net set.
func (r *family) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
isRoute := !destination.IsZero()
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
if existing, ok := r.filters[ruleID]; ok {
return existing, nil
}
srcExprs, err := r.applyNetwork(sourceNetwork(sources), sources, true)
if err != nil {
return nil, fmt.Errorf("apply source: %w", err)
}
var exprs []expr.Any
if isRoute {
exprs, err = r.buildRouteFilterExprs(srcExprs, destination, proto, sPort, dPort)
} else {
exprs, err = r.buildPeerFilterExprs(srcExprs, proto, sPort, dPort)
}
if err != nil {
r.dropNetworkMatch(srcExprs)
return nil, err
}
mainExprs := slices.Clone(exprs)
verdict := expr.VerdictAccept
if action == firewall.ActionDrop {
verdict = expr.VerdictDrop
}
mainExprs = append(mainExprs, &expr.Verdict{Kind: verdict})
chain := r.chainInputRules
if isRoute {
chain = r.chains[chainNameRoutingFw]
}
userData := []byte(ruleID)
nftRule := &nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: mainExprs,
UserData: userData,
}
if action == firewall.ActionDrop {
nftRule = r.conn.InsertRule(nftRule)
} else {
nftRule = r.conn.AddRule(nftRule)
}
if err := r.conn.Flush(); err != nil {
r.dropNetworkMatch(exprs)
return nil, fmt.Errorf(flushError, err)
}
rule := &Rule{
nftRule: nftRule,
sources: sources,
id: ruleID,
}
if !isRoute {
rule.mangleRule = r.createPreroutingRule(exprs, userData)
}
r.filters[ruleID] = rule
log.Debugf("added filter rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v",
sources, destination, proto, sPort, dPort, action)
return rule, nil
}
// buildPeerFilterExprs assembles the input-chain (peer ACL) match: the
// IP-header protocol byte read via Payload, then source, then ports
// (no counter), matching the historical peer shape so per-rule kernel
// state is identical to pre-unification.
func (r *family) buildPeerFilterExprs(
srcExprs []expr.Any,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
) ([]expr.Any, error) {
var exprs []expr.Any
if proto != firewall.ProtocolALL {
protoNum, err := r.af.protoNum(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: r.af.protoOffset,
Len: 1,
},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
)
}
exprs = append(exprs, srcExprs...)
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
return exprs, nil
}
// buildRouteFilterExprs assembles the forward-chain (route ACL) match:
// source, then destination, then optional proto/ports, then a counter.
func (r *family) buildRouteFilterExprs(
srcExprs []expr.Any,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
) ([]expr.Any, error) {
exprs := append([]expr.Any{}, srcExprs...)
destExprs, err := r.applyNetwork(destination, nil, false)
if err != nil {
return nil, fmt.Errorf("apply destination: %w", err)
}
exprs = append(exprs, destExprs...)
if proto != firewall.ProtocolALL {
protoNum, err := r.af.protoNum(proto)
if err != nil {
r.dropNetworkMatch(destExprs)
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
)
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
}
exprs = append(exprs, &expr.Counter{})
return exprs, nil
}
func (r *family) hasRule(id firewall.RuleID) bool {
_, ok := r.filters[id]
return ok
}
func (r *family) hasDNATRule(id firewall.RuleID) bool {
_, ok := r.rules[id+dnatSuffix]
return ok
}
// DeleteFilterRule removes a previously installed filter rule. Source
// set references are recovered from the stored rule's expressions via
// findSets and dropped from the shared refcounter.
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
ruleID := rule.ID()
pr, ok := r.filters[ruleID]
if !ok {
log.Debugf("filter rule %s not found", ruleID)
return nil
}
// A freshly added rule carries no handle until it is read back from
// the kernel, and Flush only refreshes the peer chains. Pull live
// handles for this rule's chain before deciding it is stale so route
// rules (which Flush never refreshes) can actually be deleted.
if pr.nftRule.Handle == 0 {
if err := r.refreshRuleHandles(pr.nftRule.Chain, false); err != nil {
log.Warnf("refresh handles for chain %s: %v", pr.nftRule.Chain.Name, err)
}
if pr.mangleRule != nil {
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
log.Warnf("refresh mangle handles: %v", err)
}
}
}
if pr.nftRule.Handle == 0 {
log.Warnf("filter rule %s has no handle, removing stale entry", ruleID)
r.dropNetworkMatch(pr.nftRule.Exprs)
delete(r.filters, ruleID)
return nil
}
if err := r.conn.DelRule(pr.nftRule); err != nil {
log.Errorf("queue rule delete: %v", err)
}
if pr.mangleRule != nil {
if err := r.conn.DelRule(pr.mangleRule); err != nil {
log.Errorf("queue mangle rule delete: %v", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete %s: %w", ruleID, err)
}
r.dropNetworkMatch(pr.nftRule.Exprs)
delete(r.filters, ruleID)
return nil
}
func (r *family) decrementSetCounter(rule *nftables.Rule) error {
if r.ipsetCounter == nil {
return nil
}
sets := findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// dropNetworkMatch undoes whatever the source/destination match
// reserved. Safe to call when the spec is empty or holds only inline
// matchers.
func (r *family) dropNetworkMatch(exprs []expr.Any) {
if r.ipsetCounter == nil {
return
}
for _, e := range exprs {
lookup, ok := e.(*expr.Lookup)
if !ok {
continue
}
if _, err := r.ipsetCounter.Decrement(lookup.SetName); err != nil {
log.Errorf("rollback ipset decrement %s: %v", lookup.SetName, err)
}
}
}
func (r *family) applyNetwork(
network firewall.Network,
setPrefixes []netip.Prefix,
isSource bool,
) ([]expr.Any, error) {
if network.IsSet() {
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
if err != nil {
side := "destination"
if isSource {
side = "source"
}
return nil, fmt.Errorf("%s set: %w", side, err)
}
return exprs, nil
}
if network.IsPrefix() {
return prefixMatchExprs(r.af, network.Prefix, isSource), nil
}
return nil, nil
}
// prefixMatchExprs is the family-aware match sequence for a CIDR
// prefix. /0 returns nil; a host prefix (full bit length for the
// family) skips the bitwise step since the mask is all-ones. Shared
// between family and aclManager so both treat single prefixes
// identically.
func prefixMatchExprs(af addrFamily, prefix netip.Prefix, isSource bool) []expr.Any {
offset := af.dstAddrOffset
if isSource {
offset = af.srcAddrOffset
}
ones := prefix.Bits()
if ones == 0 {
return nil
}
payload := &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: af.addrLen,
}
cmp := &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: prefix.Masked().Addr().AsSlice(),
}
if ones == af.totalBits {
return []expr.Any{payload, cmp}
}
mask := net.CIDRMask(ones, af.totalBits)
xor := make([]byte, af.addrLen)
return []expr.Any{
payload,
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: af.addrLen,
Mask: mask,
Xor: xor,
},
cmp,
}
}
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port == nil {
return nil
}
var exprs []expr.Any
// src
offset := uint32(2)
if isSource {
// dst
offset = 0
}
exprs = append(exprs, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: offset,
Len: 2,
})
if port.IsRange && len(port.Values) == 2 {
exprs = append(exprs,
&expr.Range{
Op: expr.CmpOpEq,
Register: 1,
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
},
)
} else {
for i, p := range port.Values {
if i > 0 {
exprs = append(exprs, &expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0x00, 0x00, 0xff, 0xff},
Xor: []byte{0x00, 0x00, 0x00, 0x00},
})
}
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(p),
})
}
}
return exprs
}
func getCtNewExprs() []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
}
// sourceNetwork classifies a source-prefix list into the firewall.Network
// shape the rest of the spec-builder consumes: empty for match-any, a
// single prefix inline, or an ipset for multiple sources.
func sourceNetwork(sources []netip.Prefix) firewall.Network {
switch {
case len(sources) == 0:
return firewall.Network{}
case len(sources) == 1 && sources[0].Bits() == 0:
return firewall.Network{}
case len(sources) == 1:
return firewall.Network{Prefix: sources[0]}
default:
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
}
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, n+"\x00")
return b
}
// findSets scans an nftables rule's expressions for expr.Lookup and
// returns the named sets in occurrence order. Used at delete time to
// drop ipsetCounter references; peer and route ACLs go through it.
func findSets(rule *nftables.Rule) []string {
var sets []string
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
sets = append(sets, lookup.SetName)
}
}
return sets
}

View File

@@ -1,90 +0,0 @@
//go:build integration && !android
package nftables
import (
"bytes"
"os"
"testing"
"github.com/google/nftables"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface"
)
// TestInterfaceAllowerInputOnly verifies the userspace-mode allower opens the
// interface on the INPUT hook of foreign chains only (not FORWARD, since the
// userspace router never forwards in the kernel), creates no netbird work
// table, and removes its rules on Close.
func TestInterfaceAllowerInputOnly(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip("root required")
}
require.False(t, ipTableExists(t, getTableName()), "precondition: no stale netbird table")
conn := &nftables.Conn{}
extTable := conn.AddTable(&nftables.Table{Name: "nbtest_extchains", Family: nftables.TableFamilyINet})
inputChain := conn.AddChain(&nftables.Chain{
Name: "ext_input", Table: extTable,
Hooknum: nftables.ChainHookInput, Priority: nftables.ChainPriorityFilter, Type: nftables.ChainTypeFilter,
})
forwardChain := conn.AddChain(&nftables.Chain{
Name: "ext_forward", Table: extTable,
Hooknum: nftables.ChainHookForward, Priority: nftables.ChainPriorityFilter, Type: nftables.ChainTypeFilter,
})
require.NoError(t, conn.Flush(), "create external table and chains")
t.Cleanup(func() {
c := &nftables.Conn{}
c.DelTable(extTable)
_ = c.Flush()
})
allower, err := NewInterfaceAllower(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "create allower")
require.NoError(t, allower.Apply(), "apply")
require.True(t, chainHasUserData(t, extTable, inputChain, userDataAcceptInputRule),
"external INPUT chain should get the accept rule")
require.Len(t, listRules(t, extTable, forwardChain), 0,
"external FORWARD chain must not be opened in userspace mode")
require.False(t, ipTableExists(t, getTableName()),
"allower must not create a netbird work table")
require.NoError(t, allower.Close(), "close")
require.False(t, chainHasUserData(t, extTable, inputChain, userDataAcceptInputRule),
"accept rule should be removed on close")
}
func listRules(t *testing.T, table *nftables.Table, chain *nftables.Chain) []*nftables.Rule {
t.Helper()
c := &nftables.Conn{}
rules, err := c.GetRules(table, chain)
require.NoError(t, err)
return rules
}
func chainHasUserData(t *testing.T, table *nftables.Table, chain *nftables.Chain, ud string) bool {
for _, r := range listRules(t, table, chain) {
if bytes.Equal(r.UserData, []byte(ud)) {
return true
}
}
return false
}
func ipTableExists(t *testing.T, name string) bool {
t.Helper()
c := &nftables.Conn{}
for _, fam := range []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyIPv6} {
tbls, err := c.ListTablesOfFamily(fam)
require.NoError(t, err)
for _, tb := range tbls {
if tb.Name == name {
return true
}
}
}
return false
}

View File

@@ -1,114 +0,0 @@
package nftables
import (
"fmt"
"github.com/google/nftables"
"github.com/hashicorp/go-multierror"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// InterfaceAllower opens the NetBird interface in the kernel's filter table and
// external chains and keeps them reconciled via a netlink monitor, so the host
// firewall doesn't drop traffic the NetBird firewall handles. It is used by the
// userspace firewall, where routing happens in the forwarder, so only INPUT is
// opened (the userspace router never forwards in the kernel).
//
// It owns its own families/connection and never creates a netbird work table.
// firewalld trust is handled by the caller, not here. Its operations are serial
// (Apply before the monitor starts; reconciles run on the single monitor
// goroutine; Close stops the monitor before removing), so it needs no locking.
//
// TODO: this opens nftables and the iptables-nft filter table (detected via
// nft), but not a legacy-iptables ruleset running in parallel with nftables.
// Such a host would keep its legacy filter chains closed for the interface.
type InterfaceAllower struct {
family4 *family
family6 *family
extMonitor *externalChainMonitor
}
// NewInterfaceAllower builds an allower for the given interface. It returns an
// error when nftables is unavailable (e.g. an iptables-legacy host), so the
// caller can fall back to firewalld trust.
func NewInterfaceAllower(wgIface iFaceMapper, mtu uint16) (*InterfaceAllower, error) {
tableName := getTableName()
family4, err := newFamily(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create family: %w", err)
}
// Probe nftables availability before committing to this backend.
if _, err := family4.conn.ListChainsOfTableFamily(nftables.TableFamilyINet); err != nil {
return nil, fmt.Errorf("nftables not available: %w", err)
}
a := &InterfaceAllower{family4: family4}
if wgIface.Address().HasIPv6() {
family6, err := newFamily(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create v6 family: %w", err)
}
a.family6 = family6
}
a.extMonitor = newExternalChainMonitor(a)
return a, nil
}
// Apply opens the interface (INPUT only) in the foreign filter chains and starts
// reconciling them on nftables changes.
func (a *InterfaceAllower) Apply() error {
var merr *multierror.Error
for _, f := range a.families() {
// Remove any stale accepts first so a prior unclean exit (e.g. SIGKILL,
// where Close never ran) is recovered deterministically rather than
// accumulating duplicate rules on the iptables filter table.
if err := f.removeAcceptFilterRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("clean stale accept rules: %w", err))
}
if err := f.openInterface(false); err != nil {
merr = multierror.Append(merr, err)
}
}
a.extMonitor.start()
return nberrors.FormatErrorOrNil(merr)
}
// families returns the configured address families (v4, and v6 when present).
func (a *InterfaceAllower) families() []*family {
families := []*family{a.family4}
if a.family6 != nil {
families = append(families, a.family6)
}
return families
}
// reconcileExternalChains re-applies the INPUT accepts to external chains. It
// implements externalChainReconciler for the monitor.
func (a *InterfaceAllower) reconcileExternalChains() error {
var merr *multierror.Error
for _, f := range a.families() {
if err := f.acceptExternalChainsRules(false); err != nil {
merr = multierror.Append(merr, err)
}
}
return nberrors.FormatErrorOrNil(merr)
}
// Close stops the monitor and removes the accept rules.
func (a *InterfaceAllower) Close() error {
a.extMonitor.stop()
var merr *multierror.Error
for _, f := range a.families() {
if err := f.removeAcceptFilterRules(); err != nil {
merr = multierror.Append(merr, err)
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -1,210 +0,0 @@
//go:build !android
package nftables
import (
"encoding/binary"
"fmt"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
func (r *family) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
set: set,
prefixes: prefixes,
})
if err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return r.getIpSetExprs(ref, isSource)
}
func (r *family) createIpSet(setName string, input setInput) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
prefixes := firewall.MergeIPRanges(input.prefixes)
nfset := &nftables.Set{
Name: setName,
Comment: input.set.Comment(),
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: r.af.setKeyType,
}
elements := r.convertPrefixesToSet(prefixes)
nElements := len(elements)
maxElements := maxPrefixesSet * 2
initialElements := elements[:min(maxElements, nElements)]
if err := r.conn.AddSet(nfset, initialElements); err != nil {
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
// The set is committed now. If a later batch fails, destroy it: the
// refcounter records nothing on a create-callback error, so it would
// otherwise leak, and a partial source set fails-open for deny rules.
if err := r.addRemainingElements(nfset, elements, maxElements); err != nil {
if derr := r.deleteIpSet(setName, nfset); derr != nil {
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
}
return nil, err
}
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
return nfset, nil
}
// addRemainingElements adds element batches beyond the initial one in
// maxElements-sized chunks, flushing each. Called after the set has been
// created with its first batch.
func (r *family) addRemainingElements(nfset *nftables.Set, elements []nftables.SetElement, maxElements int) error {
nElements := len(elements)
for subStart := maxElements; subStart < nElements; subStart += maxElements {
subEnd := min(subStart+maxElements, nElements)
subElement := elements[subStart:subEnd]
nSubPrefixes := len(subElement) / 2
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
return fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, nfset.Name, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush error: %w", err)
}
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
}
return nil
}
func (r *family) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement
for _, prefix := range prefixes {
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr()
// For a /0 the last address is the broadcast and its Next() overflows
// to an invalid Addr with an empty key, so wrap to the zero address,
// which nftables reads as the open end of a full-range interval.
var lastKey []byte
if prefix.Bits() == 0 {
lastKey = make([]byte, r.af.addrLen)
} else {
lastKey = calculateLastIP(prefix).Next().AsSlice()
}
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
elements = append(elements,
nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastKey, IntervalEnd: true},
)
}
return elements
}
// calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr {
masked := prefix.Masked()
if masked.Addr().Is4() {
hostMask := ^uint32(0) >> masked.Bits()
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
// IPv6: set host bits to all 1s
b := masked.Addr().As16()
bits := masked.Bits()
for i := bits; i < 128; i++ {
b[i/8] |= 1 << (7 - i%8)
}
return netip.AddrFrom16(b)
}
// Utility function to convert netip.Addr to uint32.
func uint32FromNetipAddr(addr netip.Addr) uint32 {
b := addr.As4()
return binary.BigEndian.Uint32(b[:])
}
// Utility function to convert uint32 to a netip-compatible byte slice.
func uint32ToBytes(ip uint32) [4]byte {
var b [4]byte
binary.BigEndian.PutUint32(b[:], ip)
return b
}
func (r *family) deleteIpSet(setName string, nfset *nftables.Set) error {
r.conn.DelSet(nfset)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("Deleted unused ipset %s", setName)
return nil
}
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
if err != nil {
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
}
// Overlapping prefixes (e.g. duplicate resolved addresses) make the
// interval set reject the batch, so merge them as createIpSet does.
prefixes = firewall.MergeIPRanges(prefixes)
elements := r.convertPrefixesToSet(prefixes)
// Add in batches sized like createIpSet so a large update does not
// exceed the netlink message size limit.
maxElements := maxPrefixesSet * 2
for start := 0; start < len(elements); start += maxElements {
end := min(start+maxElements, len(elements))
if err := r.conn.SetAddElements(nfset, elements[start:end]); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
}
log.Debugf("updated set %s with %d prefixes", set.HashedName(), len(prefixes))
return nil
}
func (r *family) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset by default
offset := r.af.dstAddrOffset
if isSource {
// src offset
offset = r.af.srcAddrOffset
}
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: r.af.addrLen,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
}, nil
}

View File

@@ -1,36 +0,0 @@
package nftables
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConvertPrefixesToSetWildcard verifies that a /0 prefix produces a
// usable interval. The last address of a /0 is the broadcast, whose Next()
// overflows to an invalid Addr with an empty key; the IntervalEnd must wrap
// to the zero address instead so nftables sees a full-range interval.
func TestConvertPrefixesToSetWildcard(t *testing.T) {
tests := []struct {
name string
af addrFamily
prefix string
}{
{"IPv4 /0", afIPv4, "0.0.0.0/0"},
{"IPv6 /0", afIPv6, "::/0"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &family{af: tt.af}
elements := r.convertPrefixesToSet([]netip.Prefix{netip.MustParsePrefix(tt.prefix)})
require.Len(t, elements, 2, "expected start and interval-end element")
assert.False(t, elements[0].IntervalEnd, "first element is the interval start")
assert.True(t, elements[1].IntervalEnd, "second element is the interval end")
assert.Len(t, elements[1].Key, int(tt.af.addrLen), "interval-end key must be a zero address, not empty")
})
}
}

View File

@@ -0,0 +1,85 @@
package nftables
import (
"net"
)
type ipsetStore struct {
ipsetReference map[string]int
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsetReference: make(map[string]int),
ipsets: make(map[string]map[string]struct{}),
}
}
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
s.ipsetReference[ipsetName] = 0
ipList := make(map[string]struct{})
s.ipsets[ipsetName] = ipList
return ipList
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
delete(s.ipsetReference, ipsetName)
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
delete(ipList, ip.String())
}
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
ipList[ip.String()] = struct{}{}
}
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return false
}
_, ok = ipList[ip.String()]
return ok
}
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
s.ipsetReference[ipsetName]++
}
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
r, ok := s.ipsetReference[ipsetName]
if !ok {
return
}
if r == 0 {
return
}
s.ipsetReference[ipsetName]--
}
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
if _, ok := s.ipsetReference[ipsetName]; !ok {
return false
}
if s.ipsetReference[ipsetName] == 0 {
return false
}
return true
}

View File

@@ -3,6 +3,7 @@ package nftables
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"sync"
@@ -15,6 +16,7 @@ import (
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -43,17 +45,18 @@ type iFaceMapper interface {
Address() wgaddr.Address
}
// Manager of nftables firewall. Per-family state (peer ACLs, route
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
// by family and provides the public firewall.Manager surface.
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
rConn *nftables.Conn
wgIface iFaceMapper
family4 *family
// IPv6 counterpart, nil when no v6 overlay.
family6 *family
router *router
aclManager *AclManager
// IPv6 counterparts, nil when no v6 overlay
router6 *router
aclManager6 *AclManager
notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain
@@ -72,9 +75,14 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
var err error
m.family4, err = newFamily(workTable, wgIface, mtu)
m.router, err = newRouter(workTable, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create family: %w", err)
return nil, fmt.Errorf("create router: %w", err)
}
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
}
if wgIface.Address().HasIPv6() {
@@ -92,21 +100,26 @@ func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mt
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
var err error
m.family6, err = newFamily(workTable6, wgIface, mtu)
m.router6, err = newRouter(workTable6, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 family: %w", err)
return fmt.Errorf("create v6 router: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.family6.ipFwdState = m.family4.ipFwdState
m.router6.ipFwdState = m.router.ipFwdState
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
return nil
}
// hasIPv6 reports whether the manager has IPv6 components initialized.
func (m *Manager) hasIPv6() bool {
return m.family6 != nil
return m.router6 != nil
}
func (m *Manager) initIPv6() error {
@@ -115,8 +128,12 @@ func (m *Manager) initIPv6() error {
return fmt.Errorf("create v6 work table: %w", err)
}
if err := m.family6.init(workTable6); err != nil {
return fmt.Errorf("v6 family init: %w", err)
if err := m.router6.init(workTable6); err != nil {
return fmt.Errorf("v6 router init: %w", err)
}
if err := m.aclManager6.init(workTable6); err != nil {
return fmt.Errorf("v6 acl manager init: %w", err)
}
return nil
@@ -139,20 +156,19 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// reconcileExternalChains re-applies passthrough accept rules to external
// filter chains for both IPv4 and IPv6 routers. Called by the monitor when
// tables or chains appear (e.g. after firewalld reloads). Kernel routing opens
// both INPUT and FORWARD.
// tables or chains appear (e.g. after firewalld reloads).
func (m *Manager) reconcileExternalChains() error {
m.mutex.Lock()
defer m.mutex.Unlock()
var merr *multierror.Error
if m.family4 != nil {
if err := m.family4.acceptExternalChainsRules(true); err != nil {
if m.router != nil {
if err := m.router.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
}
}
if m.hasIPv6() {
if err := m.family6.acceptExternalChainsRules(true); err != nil {
if err := m.router6.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
}
}
@@ -171,8 +187,12 @@ func (m *Manager) initFirewall() (err error) {
}
}()
if err := m.family4.init(workTable); err != nil {
return fmt.Errorf("family init: %w", err)
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
return fmt.Errorf("acl manager init: %w", err)
}
if m.hasIPv6() {
@@ -200,7 +220,7 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.family4.mtu,
MTU: m.router.mtu,
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -215,12 +235,12 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
func (m *Manager) rollbackInit() {
if err := m.family4.Reset(); err != nil {
log.Warnf("rollback family: %v", err)
if err := m.router.Reset(); err != nil {
log.Warnf("rollback router: %v", err)
}
if m.hasIPv6() {
if err := m.family6.Reset(); err != nil {
log.Warnf("rollback v6 family: %v", err)
if err := m.router6.Reset(); err != nil {
log.Warnf("rollback v6 router: %v", err)
}
}
if err := m.cleanupNetbirdTables(); err != nil {
@@ -231,77 +251,118 @@ func (m *Manager) rollbackInit() {
}
}
// AddFilterRule installs a packet-filtering rule.
// AddPeerFiltering rule to the firewall
//
// Destination semantics: zero Network → input chain (peer ACL);
// set Network → forward chain (route ACL).
//
// Sources are a single address family; the rule is dispatched to the
// matching per-family backend.
func (m *Manager) AddFilterRule(
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if len(sources) == 0 {
return nil, firewall.ErrNoSources
}
ipsetName string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
fam := m.family4
if isIPv6Rule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
}
fam = m.family6
if ip.To4() != nil {
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
// DeleteFilterRule removes a filtering rule. The owning family is found
// by id, refreshing from the kernel if the in-memory caches miss so a
// stale cache cannot leak the rule. family.DeleteFilterRule is idempotent
// when the id is absent.
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
fam, err := m.familyForRuleID(rule.ID(), (*family).hasRule)
if isIPv6RouteRule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6Rule(rule) {
return m.aclManager6.DeletePeerRule(rule)
}
return m.aclManager.DeletePeerRule(rule)
}
func isIPv6Rule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
}
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
// For static routes, the destination prefix determines the family. For dynamic
// routes (DomainSet), the sources determine the family since management
// duplicates dynamic rules per family.
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
// router; the cached maps are normally authoritative, so the kernel is only
// consulted when neither map knows about the rule.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
id := rule.ID()
r, err := m.routerForRuleID(id, (*router).hasRule)
if err != nil {
return err
}
return fam.DeleteFilterRule(rule)
return r.DeleteRouteRule(rule)
}
// familyForRuleID picks the family holding the rule with the given id, using
// routerForRuleID picks the router holding the rule with the given id, using
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
// from the kernel once and re-checks before falling back to the v4 family.
func (m *Manager) familyForRuleID(id firewall.RuleID, has func(*family, firewall.RuleID) bool) (*family, error) {
if has(m.family4, id) {
return m.family4, nil
// from the kernel once and re-checks before falling back to the v4 router.
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
if has(m.router, id) {
return m.router, nil
}
if m.hasIPv6() && has(m.family6, id) {
return m.family6, nil
if m.hasIPv6() && has(m.router6, id) {
return m.router6, nil
}
if !m.hasIPv6() {
return m.family4, nil
return m.router, nil
}
if err := m.family4.refreshRulesMap(); err != nil {
if err := m.router.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v4 rules: %w", err)
}
if err := m.family6.refreshRulesMap(); err != nil {
if err := m.router6.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v6 rules: %w", err)
}
if has(m.family6, id) && !has(m.family4, id) {
return m.family6, nil
if has(m.router6, id) && !has(m.router, id) {
return m.router6, nil
}
return m.family4, nil
return m.router, nil
}
func (m *Manager) IsServerRouteSupported() bool {
@@ -320,10 +381,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddNatRule(pair)
return m.router6.AddNatRule(pair)
}
if err := m.family4.AddNatRule(pair); err != nil {
if err := m.router.AddNatRule(pair); err != nil {
return err
}
@@ -335,7 +396,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// so the eventual cleanup still works.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.family6.AddNatRule(v6Pair); err != nil {
if err := m.router6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
@@ -351,18 +412,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return nil
}
return m.family6.RemoveNatRule(pair)
return m.router6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.family4.RemoveNatRule(pair); err != nil {
if err := m.router.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
@@ -370,13 +431,46 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
return nberrors.FormatErrorOrNil(merr)
}
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
//
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
// which doesn't override DROP rules in external tables (e.g. firewalld).
// Should add passthrough rules to external chains (like the native mode router's
// addExternalChainsRules does) for both the netbird table family and inet tables.
// The netbird table itself is fine (routing chains already exist there), but
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.aclManager.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err)
}
if m.hasIPv6() {
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create v6 default allow rules: %w", err)
}
}
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush allow input netbird rules: %w", err)
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil
}
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.family6, isLegacy)
return firewall.SetLegacyManagement(m.router6, isLegacy)
}
return nil
}
@@ -390,13 +484,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
var merr *multierror.Error
if err := m.family4.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
}
if m.hasIPv6() {
if err := m.family6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
}
}
@@ -436,14 +530,14 @@ func (m *Manager) SetLogLevel(log.Level) {
}
func (m *Manager) EnableRouting() error {
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil
}
func (m *Manager) DisableRouting() error {
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil
@@ -457,12 +551,12 @@ func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.family4.Flush(); err != nil {
if err := m.aclManager.Flush(); err != nil {
return err
}
if m.hasIPv6() {
if err := m.family6.Flush(); err != nil {
if err := m.aclManager6.Flush(); err != nil {
return fmt.Errorf("flush v6 acl: %w", err)
}
}
@@ -483,9 +577,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddDNATRule(rule)
return m.router6.AddDNATRule(rule)
}
return m.family4.AddDNATRule(rule)
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
@@ -493,7 +587,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
r, err := m.familyForRuleID(rule.ID(), (*family).hasDNATRule)
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
if err != nil {
return err
}
@@ -514,12 +608,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
}
}
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
@@ -536,9 +630,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
@@ -550,9 +644,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
@@ -564,9 +658,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
@@ -578,9 +672,9 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
const (
@@ -809,14 +903,3 @@ func getEstablishedExprs(register uint32) []expr.Any {
},
}
}
// isIPv6Rule reports whether the rule belongs to the v6 table. For a
// prefix destination the destination family decides; otherwise the
// (single-family) sources do, since management duplicates rules per
// family.
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}

View File

@@ -1,5 +1,3 @@
//go:build integration && !android
package nftables
import (
@@ -72,13 +70,13 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop)
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 2, "expected 2 rules")
@@ -149,12 +147,15 @@ func TestNftablesManager(t *testing.T) {
// Compare connection tracking rule at position 1 (pushed down by DROP rule insertion)
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
require.NoError(t, manager.DeleteFilterRule(rule), "failed to delete rule")
for _, r := range rule {
err = manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
// established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion")
@@ -179,39 +180,47 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
testClient := &nftables.Conn{}
// Add accept rule first
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http")
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for the same traffic
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http")
require.NoError(t, err, "failed to add deny rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
t.Logf("Found %d rules in nftables chain", len(rules))
// Single-source rules emit a direct payload+cmp on the source IP
// (no set lookup). Match by source-IP + port + verdict instead of
// the legacy per-(action,port) set names ("deny-http"/"accept-http")
// that this test predates.
wantSrc := ip.AsSlice()
// Find the accept and deny rules and verify deny comes before accept
var acceptRuleIndex, denyRuleIndex = -1, -1
for i, rule := range rules {
var hasSrc, hasPort80 bool
hasAcceptHTTPSet := false
hasDenyHTTPSet := false
hasPort80 := false
var action string
for _, e := range rule.Exprs {
if cmp, ok := e.(*expr.Cmp); ok && cmp.Op == expr.CmpOpEq {
if bytes.Equal(cmp.Data, wantSrc) {
hasSrc = true
// Check for set lookup
if lookup, ok := e.(*expr.Lookup); ok {
switch lookup.SetName {
case "accept-http":
hasAcceptHTTPSet = true
case "deny-http":
hasDenyHTTPSet = true
}
if len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
}
// Check for port 80
if cmp, ok := e.(*expr.Cmp); ok {
if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
hasPort80 = true
}
}
// Check for verdict
if verdict, ok := e.(*expr.Verdict); ok {
switch verdict.Kind {
case expr.VerdictAccept:
@@ -222,15 +231,11 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
}
}
if !hasSrc || !hasPort80 {
continue
}
switch action {
case "ACCEPT":
t.Logf("Rule [%d]: src=%s port=80 ACCEPT", i, ip)
if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" {
t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i)
acceptRuleIndex = i
case "DROP":
t.Logf("Rule [%d]: src=%s port=80 DROP", i, ip)
} else if hasDenyHTTPSet && hasPort80 && action == "DROP" {
t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i)
denyRuleIndex = i
}
}
@@ -274,7 +279,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@@ -356,10 +361,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
})
ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddFilterRule(
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
@@ -432,10 +437,10 @@ func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
})
ip := netip.MustParseAddr("fd00::2")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "add v6 peer filtering rule")
_, err = manager.AddFilterRule(
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
@@ -545,7 +550,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
}
}
_, err = manager.AddFilterRule(
_, err = manager.AddRouteFiltering(
nil,
prefixes,
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
@@ -560,7 +565,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerCompatibilityWithIptablesForWildcardSource(t *testing.T) {
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
@@ -586,9 +591,9 @@ func TestNftablesManagerCompatibilityWithIptablesForWildcardSource(t *testing.T)
verifyIptablesOutput(t, stdout, stderr)
})
_, err = manager.AddFilterRule(
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
[]netip.Prefix{},
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
//go:build integration && !android
//go:build !android
package nftables
@@ -37,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and family for all chains to be present
// need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -47,7 +47,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
nftablesTestingClient := &nftables.Conn{}
rtr := manager.family4
rtr := manager.router
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted")
@@ -90,9 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
}
// Build CIDR matching expressions
testRouter := &family{af: afIPv4}
sourceExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Source.Prefix, true)
destExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Destination.Prefix, false)
testRouter := &router{af: afIPv4}
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
// Combine all expressions in the correct order
// nolint:gocritic
@@ -100,14 +100,14 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0
for _, chain := range rtr.chains {
if chain.Name == chainNameManglePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
// Compare expressions up to the mark setting expressions
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1
@@ -135,19 +135,19 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
rtr := manager.family4
rtr := manager.router
// First add the NAT rule using the family's method
// First add the NAT rule using the router's method
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule")
// Verify the rule was added
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules")
for _, rule := range rules {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
@@ -163,7 +163,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
@@ -200,11 +200,11 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family")
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func(r *family) {
defer func(r *router) {
require.NoError(t, r.Reset(), "Failed to reset rules")
}(r)
@@ -314,16 +314,16 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddFilterRule failed")
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
})
stored, ok := r.filters[id.RuleID(ruleKey.ID())]
require.True(t, ok, "Rule not found in filters map")
rule := stored.nftRule
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
t.Log("Internal rule expressions:")
for i, expr := range rule.Exprs {
@@ -339,7 +339,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
var nftRule *nftables.Rule
for _, rule := range rules {
if firewall.RuleID(rule.UserData) == ruleKey.ID() {
if string(rule.UserData) == ruleKey.ID() {
nftRule = rule
break
}
@@ -367,12 +367,12 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family")
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset family")
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
@@ -509,42 +509,6 @@ func TestNftablesCreateIpSet(t *testing.T) {
}
}
// TestNftablesUpdateSetMergesOverlapping verifies that UpdateSet merges
// overlapping prefixes before adding them. An interval set rejects
// overlapping elements, so without the merge a batch holding a /32 already
// covered by a /24, or a duplicate address as DNS resolution can produce,
// would fail.
func TestNftablesUpdateSetMergesOverlapping(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "create work table")
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "create family")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "reset family")
}()
initial := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
set := firewall.NewPrefixSet(initial)
created, err := r.createIpSet(set.HashedName(), setInput{prefixes: initial})
require.NoError(t, err, "create ip set")
require.NotNil(t, created)
overlapping := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("192.168.1.1/32"),
}
require.NoError(t, r.UpdateSet(set, overlapping), "UpdateSet must merge overlapping prefixes")
}
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
@@ -554,11 +518,11 @@ func TestNftablesCreateIpSet_IPv6(t *testing.T) {
require.NoError(t, err, "Failed to create v6 work table")
defer deleteWorkTableIPv6()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family")
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset family")
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
@@ -897,13 +861,13 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
require.NoError(t, err)
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Add a real rule to the kernel
ruleKey, err := r.AddFilterRule(
ruleKey, err := r.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
@@ -914,11 +878,11 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.DeleteFilterRule(ruleKey))
require.NoError(t, r.DeleteRouteRule(ruleKey))
})
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
staleKey := firewall.RuleID("stale-rule-that-does-not-exist")
staleKey := "stale-rule-that-does-not-exist"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
@@ -938,55 +902,6 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
// TestRouter_DeleteRouteRule_RemovesKernelRule verifies a route filter
// rule is actually removed from the kernel on delete. The route chain is
// not refreshed by Flush, so the stored rule carries a zero handle;
// DeleteFilterRule must pull live handles itself before issuing the
// delete or the kernel rule leaks. Regression test for that path.
func TestRouter_DeleteRouteRule_RemovesKernelRule(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
ruleKey, err := r.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
firewall.ProtocolTCP,
nil,
&firewall.Port{Values: []uint16{80}},
firewall.ActionAccept,
)
require.NoError(t, err)
countKernelRules := func() int {
list, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
require.NoError(t, err)
n := 0
for _, rule := range list {
if string(rule.UserData) == string(ruleKey.ID()) {
n++
}
}
return n
}
require.Equal(t, 1, countKernelRules(), "rule should be present in the kernel after add")
require.NoError(t, r.DeleteFilterRule(ruleKey))
assert.Equal(t, 0, countKernelRules(), "rule must be removed from the kernel after delete")
assert.NotContains(t, r.filters, ruleKey.ID(), "filters map entry should be cleared")
}
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
@@ -996,28 +911,24 @@ func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
require.NoError(t, err)
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Inject a stale entry with Handle=0
staleKey := id.RuleID("stale-route-rule")
staleRule := &Rule{
nftRule: &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
},
id: staleKey,
staleKey := "stale-route-rule"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
r.filters[staleKey] = staleRule
// DeleteFilterRule should not return an error for stale handles
err = r.DeleteFilterRule(staleRule)
// DeleteRouteRule should not return an error for stale handles
err = r.DeleteRouteRule(id.RuleID(staleKey))
assert.NoError(t, err, "deleting a stale rule should not error")
assert.NotContains(t, r.filters, staleKey, "stale entry should be cleaned up")
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
}
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
@@ -1039,7 +950,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
Masquerade: true,
}
rtr := manager.family4
rtr := manager.router
// First add succeeds
err = rtr.AddNatRule(pair)
@@ -1049,11 +960,11 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
})
// Corrupt the handle to simulate stale state
natRuleKey := pair.GenKey(firewall.PreroutingFormat)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := rtr.rules[natRuleKey]; exists {
rule.Handle = 0
}
inverseKey := firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat)
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
if rule, exists := rtr.rules[inverseKey]; exists {
rule.Handle = 0
}
@@ -1068,7 +979,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
found := 0
for _, rule := range rules {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found++
}
}
@@ -1099,7 +1010,7 @@ func TestCalculateLastIP(t *testing.T) {
}
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
r := &family{af: afIPv6}
r := &router{af: afIPv6}
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::1/128"),

View File

@@ -1,494 +0,0 @@
//go:build !android
package nftables
import (
"fmt"
"strings"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) AddNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
r.rollbackRules(pair)
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if pair.Masquerade {
if err := r.addNatRule(pair); err != nil {
r.rollbackRules(pair)
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
r.rollbackRules(pair)
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
r.rollbackRules(pair)
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
}
return nil
}
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
func (r *family) rollbackRules(pair firewall.RouterPair) {
keys := []firewall.RuleID{
pair.GenKey(firewall.ForwardingFormat),
pair.GenKey(firewall.PreroutingFormat),
firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat),
}
for _, key := range keys {
rule, ok := r.rules[key]
if !ok {
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("rollback set counter for %s: %v", key, err)
}
delete(r.rules, key)
}
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *family) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
op := expr.CmpOpEq
if pair.Inverse {
op = expr.CmpOpNeq
}
exprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: op,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs = append(exprs, getCtNewExprs()...)
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(markValue),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
)
ruleID := pair.GenKey(firewall.PreroutingFormat)
if _, exists := r.rules[ruleID]; exists {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
}
// Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
r.rules[ruleID] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: exprs,
UserData: []byte(ruleID),
})
return nil
}
func (r *family) addPostroutingRules() {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
})
// Second masquerade rule for traffic going out through WireGuard interface
exprs2 := []expr.Any{
// Match on the second fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
},
// Match WireGuard interface
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
func (r *family) addMSSClampingRules() error {
overhead := uint16(ipv4TCPHeaderSize)
if r.af.tableFamily == nftables.TableFamilyIPv6 {
overhead = ipv6TCPHeaderSize
}
if r.mtu <= overhead {
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
return nil
}
mss := r.mtu - overhead
exprsOut := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 13,
Len: 1,
},
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 1,
Mask: []byte{0x02},
Xor: []byte{0x00},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0x00},
},
&expr.Counter{},
&expr.Exthdr{
DestRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
&expr.Cmp{
Op: expr.CmpOpGt,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Exthdr{
SourceRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameMangleForward],
Exprs: exprsOut,
})
return r.conn.Flush()
}
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
var exprs []expr.Any
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
exprs = append(exprs,
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
)
ruleID := pair.GenKey(firewall.ForwardingFormat)
if _, exists := r.rules[ruleID]; exists {
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
}
r.rules[ruleID] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Exprs: exprs,
UserData: []byte(ruleID),
})
return nil
}
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.ForwardingFormat)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
return r.deleteLegacyRuleEntry(ruleID, rule)
}
// deleteLegacyRuleEntry removes one legacy forwarding rule and drops its
// ipset references. It also clears stale entries that never got a handle.
func (r *family) deleteLegacyRuleEntry(ruleID firewall.RuleID, rule *nftables.Rule) error {
if rule.Handle == 0 {
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleID)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
}
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s: %w", ruleID, err)
}
delete(r.rules, ruleID)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}
// GetLegacyManagement returns the route manager's legacy management mode
func (r *family) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *family) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *family) RemoveAllLegacyRouteRules() error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
continue
}
if err := r.deleteLegacyRuleEntry(k, rule); err != nil {
merr = multierror.Append(merr, err)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) removeNatPreroutingRules() error {
table := &nftables.Table{
Name: tableNat,
Family: r.af.tableFamily,
}
chain := &nftables.Chain{
Name: chainNameNatPrerouting,
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
}
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from nat table: %w", err)
}
var merr *multierror.Error
// Delete rules that have our UserData suffix
for _, rule := range rules {
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), string(dnatSuffix)) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
}
// Set counters are decremented in the sub-methods above before flush. If flush fails,
// counters will be off until the next successful removal or refresh cycle.
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) removeNatRule(pair firewall.RouterPair) error {
ruleID := pair.GenKey(firewall.PreroutingFormat)
rule, exists := r.rules[ruleID]
if !exists {
log.Debugf("prerouting rule %s not found", ruleID)
return nil
}
if rule.Handle == 0 {
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleID)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
}
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleID)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}

View File

@@ -1,26 +1,21 @@
package nftables
import (
"net/netip"
"net"
"github.com/google/nftables"
"github.com/netbirdio/netbird/client/firewall/manager"
)
// Rule wraps an installed filter rule (peer or route). Source set
// membership is encoded in the rule's expressions; DeleteFilterRule
// recovers the set name via findSets so the refcounter can drop the
// right reference. mangleRule is set only for peer rules.
// Rule to handle management of rules
type Rule struct {
nftRule *nftables.Rule
mangleRule *nftables.Rule
// sources is the canonical source list this rule was created for.
sources []netip.Prefix
id manager.RuleID
nftSet *nftables.Set
ruleID string
ip net.IP
}
// ID returns the rule id
func (r *Rule) ID() manager.RuleID {
return r.id
// GetRuleID returns the rule id
func (r *Rule) ID() string {
return r.ruleID
}

View File

@@ -1,27 +0,0 @@
//go:build integration && !android
package nftables
import (
"fmt"
"net"
"net/netip"
)
func pfx(ip net.IP) []netip.Prefix {
if ip == nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if ip.IsUnspecified() {
if ip.To4() != nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
a, ok := netip.AddrFromSlice(ip)
if !ok {
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
}
a = a.Unmap()
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
}

View File

@@ -0,0 +1,37 @@
//go:build !windows
package uspfilter
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/firewalld"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Close cleans up the firewall manager by removing all rules and closing trackers
func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.resetState()
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)
}
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to untrust interface in firewalld: %v", err)
}
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if m.nativeFirewall != nil {
return m.nativeFirewall.AllowNetbird()
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil
}

View File

@@ -9,6 +9,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type action string
@@ -19,20 +20,35 @@ const (
firewallRuleName = "Netbird"
)
// WindowsInterfaceAllower opens the NetBird interface in the Windows firewall
// via netsh advfirewall rules. It implements InterfaceAllower for the userspace
// firewall on Windows.
type WindowsInterfaceAllower struct {
iface Iface
// Close cleans up the firewall manager by removing all rules and closing trackers
func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.resetState()
if !isWindowsFirewallReachable() {
return nil
}
var merr *multierror.Error
if isFirewallRuleActive(firewallRuleName) {
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
}
}
if isFirewallRuleActive(firewallRuleName + "-v6") {
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// NewWindowsInterfaceAllower builds the Windows netsh-based interface allower.
func NewWindowsInterfaceAllower(iface Iface) *WindowsInterfaceAllower {
return &WindowsInterfaceAllower{iface: iface}
}
// Apply adds inbound-allow netsh rules for the interface's addresses.
func (a *WindowsInterfaceAllower) Apply() error {
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if !isWindowsFirewallReachable() {
return nil
}
@@ -44,13 +60,13 @@ func (a *WindowsInterfaceAllower) Apply() error {
"enable=yes",
"action=allow",
"profile=any",
"localip="+a.iface.Address().IP.String(),
"localip="+m.wgIface.Address().IP.String(),
); err != nil {
return err
}
}
if v6 := a.iface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
if err := manageFirewallRule(firewallRuleName+"-v6",
addRule,
"dir=in",
@@ -66,27 +82,8 @@ func (a *WindowsInterfaceAllower) Apply() error {
return nil
}
// Close removes the netsh rules added by Apply.
func (a *WindowsInterfaceAllower) Close() error {
if !isWindowsFirewallReachable() {
return nil
}
var merr *multierror.Error
if isFirewallRuleActive(firewallRuleName) {
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
}
}
if isFirewallRuleActive(firewallRuleName + "-v6") {
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
args := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
if action == addRule {
args = append(args, extraArgs...)

View File

@@ -0,0 +1,17 @@
package common
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
SetFilter(device.PacketFilter) error
Address() wgaddr.Address
GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"os"
"slices"
@@ -19,18 +20,14 @@ import (
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -61,10 +58,7 @@ const (
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
// EnvForceUserspaceRouter is a deprecated alias for
// NB_FORCE_USERSPACE_FIREWALL: the userspace firewall always routes in
// userspace, so forcing one forces the other. Kept for backward
// compatibility.
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
@@ -76,19 +70,14 @@ const (
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
)
// errNotSupported is returned by firewall operations that only make sense with
// a kernel firewall (kernel NAT/DNAT, eBPF) and are not implemented in
// userspace mode, where they should not be called.
var errNotSupported = errors.New("not supported with userspace firewall")
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// peerRules is the canonical list-based storage for peer ACL rules.
// Match order is significant: drop rules come before accept rules so
// callers should consult the slice in order.
type peerRules []*PeerRule
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule
type routeRules []*RouteRule
type RouteRules []*RouteRule
func (r routeRules) Sort() {
func (r RouteRules) Sort() {
slices.SortStableFunc(r, func(a, b *RouteRule) int {
// Deny rules come first
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
@@ -97,75 +86,22 @@ func (r routeRules) Sort() {
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
return 1
}
return strings.Compare(string(a.id), string(b.id))
return strings.Compare(a.id, b.id)
})
}
// peerRuleSpec carries the parameters that define a peer filter rule,
// threaded together through the build path so the builders take a single
// argument instead of a long parameter list.
type peerRuleSpec struct {
mgmtID []byte
sources []netip.Prefix
ipLayer gopacket.LayerType
matchAny bool
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
action firewall.Action
}
// Iface is the network interface the userspace firewall attaches to: the
// methods of the WireGuard device it actually uses.
type Iface interface {
Name() string
Address() wgaddr.Address
SetFilter(device.PacketFilter) error
GetWGDevice() *wgdevice.Device
}
// InterfaceAllower opens the NetBird interface in the host firewall so it
// doesn't drop traffic the userspace firewall handles, without taking over
// packet filtering. Implementations (nftables, iptables, firewalld, the windows
// netsh rule) are selected per platform and injected into Create; Apply runs at
// creation and Close on teardown.
type InterfaceAllower interface {
Apply() error
Close() error
}
// Config holds the dependencies and options for the userspace firewall.
type Config struct {
// IFace is the overlay interface the filter attaches to.
IFace Iface
// InterfaceAllower opens the NetBird interface in foreign kernel filter
// chains so the kernel doesn't drop traffic the userspace firewall handles.
// Nil in netstack mode, on non-Linux platforms without a backend, or when
// neither nftables nor iptables is available. firewalld trust is applied by
// the manager regardless, since firewalld owns its own chains and we cannot
// insert into them.
InterfaceAllower InterfaceAllower
// DisableServerRoutes indicates whether server routes are disabled.
DisableServerRoutes bool
FlowLogger nftypes.FlowLogger
MTU uint16
}
// Manager userspace firewall manager
type Manager struct {
decoders sync.Pool
wgIface Iface
ifaceAllower InterfaceAllower
mutex sync.RWMutex
outgoingRules map[netip.Addr]RuleSet
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
incomingDenyRules peerRules
incomingAcceptRules peerRules
incomingDenyIndex peerRuleIndex
incomingAcceptIndex peerRuleIndex
peerRulesMap map[nbid.RuleID]*PeerRule
routeRules routeRules
routeRulesMap map[nbid.RuleID]*RouteRule
mutex sync.RWMutex
// indicates whether server routes are disabled
disableServerRoutes bool
@@ -247,6 +183,24 @@ func (d *decoder) decodePacket(data []byte) error {
}
}
// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
}
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
if nativeFirewall == nil {
return nil, errors.New("native firewall is nil")
}
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}
return mgr, nil
}
func parseCreateEnv() (bool, bool, bool) {
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
var err error
@@ -277,7 +231,7 @@ func parseCreateEnv() (bool, bool, bool) {
return disableConntrack, enableLocalForwarding, disableMSSClamping
}
func Create(cfg Config) (_ *Manager, err error) {
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
m := &Manager{
@@ -300,131 +254,62 @@ func Create(cfg Config) (_ *Manager, err error) {
return d
},
},
wgIface: cfg.IFace,
ifaceAllower: cfg.InterfaceAllower,
nativeFirewall: nativeFirewall,
outgoingRules: make(map[netip.Addr]RuleSet),
incomingDenyRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[netip.Addr]RuleSet),
wgIface: iface,
localipmanager: newLocalIPManager(),
disableServerRoutes: cfg.DisableServerRoutes,
disableServerRoutes: disableServerRoutes,
stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()),
flowLogger: cfg.FlowLogger,
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
peerRulesMap: make(map[nbid.RuleID]*PeerRule),
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
mtu: cfg.MTU,
mtu: mtu,
}
m.routingEnabled.Store(false)
// Release the allower (and its monitor) if setup fails after it was wired in.
defer func() {
if err != nil {
m.closeAllowerOnError()
}
}()
if !disableMSSClamping {
m.enableMSSClamping(cfg.MTU)
m.mssClampEnabled = true
if mtu > ipv4TCPHeaderMinSize {
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
}
if mtu > ipv6TCPHeaderMinSize {
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
}
}
if err := m.localipmanager.UpdateLocalIPs(cfg.IFace); err != nil {
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err)
}
m.setupConntrack(disableConntrack)
if disableConntrack {
log.Info("conntrack is disabled")
} else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
}
if m.netstack && m.localForwarding {
if err := m.initForwarder(); err != nil {
log.Errorf("failed to initialize forwarder: %v", err)
}
}
if err := cfg.IFace.SetFilter(m); err != nil {
if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err)
}
m.openHostFirewall(cfg.IFace.Name())
return m, nil
}
// closeAllowerOnError releases the allower (and its monitor) when Create fails
// after the allower was wired in.
func (m *Manager) closeAllowerOnError() {
if m.ifaceAllower == nil {
return
}
if err := m.ifaceAllower.Close(); err != nil {
log.Warnf("close interface allower after failed firewall setup: %v", err)
}
}
// enableMSSClamping enables MSS clamping and computes the per-family clamp values.
func (m *Manager) enableMSSClamping(mtu uint16) {
m.mssClampEnabled = true
if mtu > ipv4TCPHeaderMinSize {
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
}
if mtu > ipv6TCPHeaderMinSize {
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
}
}
// setupConntrack initializes the stateful trackers unless conntrack is disabled.
func (m *Manager) setupConntrack(disabled bool) {
if disabled {
log.Info("conntrack is disabled")
return
}
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
}
// openHostFirewall opens the NetBird interface in the kernel firewall so it
// doesn't drop traffic the userspace firewall handles. Best-effort: failures
// here shouldn't prevent the firewall from coming up.
func (m *Manager) openHostFirewall(ifaceName string) {
if m.ifaceAllower != nil {
if err := m.ifaceAllower.Apply(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
}
// firewalld owns its own chains we can't insert into, so trust the interface
// there in addition to the allower. Netstack has no kernel interface.
if !m.netstack {
if err := firewalld.TrustInterface(ifaceName); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
}
}
// Close cleans up the firewall manager: removes rules, closes trackers, and
// closes the interface allower.
func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.resetState()
var merr *multierror.Error
if m.ifaceAllower != nil {
if err := m.ifaceAllower.Close(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("close interface allower: %w", err))
}
}
if !m.netstack {
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
merr = multierror.Append(merr, fmt.Errorf("untrust interface in firewalld: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// blockInvalidRouted installs drop rules for traffic to the wg overlay that
// arrives via the routing path. v4 and v6 are independent: a v6 install
// failure leaves v4 protection in place (and vice versa) so the returned
// slice always contains whatever was successfully installed, even on error.
// Callers must persist the slice so DisableRouting can clean partial state.
func (m *Manager) blockInvalidRouted(iface Iface) ([]firewall.Rule, error) {
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule, error) {
wgPrefix := iface.Address().Network
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
@@ -435,7 +320,7 @@ func (m *Manager) blockInvalidRouted(iface Iface) ([]firewall.Rule, error) {
}
var rules []firewall.Rule
v4Rule, err := m.addRouteRule(
v4Rule, err := m.addRouteFiltering(
nil,
sources,
firewall.Network{Prefix: wgPrefix},
@@ -451,7 +336,7 @@ func (m *Manager) blockInvalidRouted(iface Iface) ([]firewall.Rule, error) {
if v6Net.IsValid() {
log.Debugf("blocking invalid routed traffic for %s", v6Net)
v6Rule, err := m.addRouteRule(
v6Rule, err := m.addRouteFiltering(
nil,
sources,
firewall.Network{Prefix: v6Net},
@@ -472,14 +357,20 @@ func (m *Manager) blockInvalidRouted(iface Iface) ([]firewall.Rule, error) {
}
func (m *Manager) determineRouting() error {
var disableUspRouting bool
var disableUspRouting, forceUserspaceRouter bool
var err error
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
var err error
disableUspRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
}
}
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
forceUserspaceRouter, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
}
}
switch {
case disableUspRouting:
@@ -494,11 +385,26 @@ func (m *Manager) determineRouting() error {
log.Info("server routes are disabled")
case forceUserspaceRouter:
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
log.Info("userspace routing is forced")
case !m.netstack && m.nativeFirewall != nil:
// if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
log.Info("native routing is enabled")
default:
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
log.Info("userspace routing enabled")
log.Info("userspace routing enabled by default")
}
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
@@ -564,147 +470,82 @@ func (m *Manager) IsStateful() bool {
return m.stateful
}
func (m *Manager) AddNatRule(firewall.RouterPair) error {
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair)
}
// userspace routed packets are always SNATed to the inbound direction
// TODO: implement outbound SNAT
return nil
}
// RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(firewall.RouterPair) error {
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.RemoveNatRule(pair)
}
return nil
}
// addPeerRule installs an input-chain rule that matches packets
// by source only. Called from AddFilterRule when the caller doesn't
// specify a destination. Mixed-family inputs are split: each family
// gets its own rule with a family-correct ipLayer so packet decoding
// matches what the matcher expects.
func (m *Manager) addPeerRule(
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
id []byte,
sources []netip.Prefix,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
_ string,
) ([]firewall.Rule, error) {
// TODO: fix in upper layers
i, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("invalid IP: %s", ip)
}
i = i.Unmap()
r := PeerRule{
id: uuid.New().String(),
mgmtId: id,
ip: i,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
drop: action == firewall.ActionDrop,
}
if i.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
r.matchByIP = false
}
r.sPort = sPort
r.dPort = dPort
r.protoLayer = protoToLayer(proto, r.ipLayer)
m.mutex.Lock()
defer m.mutex.Unlock()
if sourcesMatchAny(sources) {
spec := peerRuleSpec{
mgmtID: id,
sources: sources,
ipLayer: layerTypeAll,
matchAny: true,
proto: proto,
sPort: sPort,
dPort: dPort,
action: action,
}
return m.addOnePeerRule(spec), nil
}
// Sources are a single family; normalize v4-mapped prefixes to plain
// v4 and pick the matching IP layer.
normalized := make([]netip.Prefix, len(sources))
ipLayer := layers.LayerTypeIPv4
for i, p := range sources {
normalized[i] = firewall.UnmapPrefix(p)
if normalized[i].Addr().Is6() {
ipLayer = layers.LayerTypeIPv6
}
}
spec := peerRuleSpec{
mgmtID: id,
sources: normalized,
ipLayer: ipLayer,
matchAny: false,
proto: proto,
sPort: sPort,
dPort: dPort,
action: action,
}
return m.addOnePeerRule(spec), nil
}
// addOnePeerRule builds and registers a single-family peer rule, or
// returns the existing rule when one with the same content key is
// already installed. The caller must hold m.mutex. The content key is
// the shared GenerateRuleID with an empty destination, so peer
// rules dedup the same way route rules and the kernel backends do.
//
// There is no refcount: a content key is installed once and deleted on
// the first DeleteFilterRule for that key. The caller must therefore
// key its own tracking by the returned rule id so add and delete stay
// balanced per content key; the acl manager does this via
// peerRulesPairs. The content key is order-independent, so callers
// passing the same sources in any order dedup to one rule.
func (m *Manager) addOnePeerRule(spec peerRuleSpec) *PeerRule {
ruleID := nbid.GenerateRuleID(spec.sources, firewall.Network{}, spec.proto, spec.sPort, spec.dPort, spec.action)
if existing, ok := m.peerRulesMap[ruleID]; ok {
return existing
}
rule := m.buildPeerRule(ruleID, spec)
m.registerPeerRule(rule)
return rule
}
func (m *Manager) buildPeerRule(ruleID nbid.RuleID, spec peerRuleSpec) *PeerRule {
r := &PeerRule{
id: ruleID,
mgmtId: spec.mgmtID,
sources: spec.sources,
matchAny: spec.matchAny,
action: spec.action,
srcPort: spec.sPort,
dstPort: spec.dPort,
}
if !spec.matchAny {
r.sourceAddrs = make(map[netip.Addr]struct{}, len(spec.sources))
for _, p := range spec.sources {
if p.Bits() == p.Addr().BitLen() {
r.sourceAddrs[p.Addr()] = struct{}{}
}
}
}
r.protoLayer = protoToLayer(spec.proto, spec.ipLayer)
return r
}
// registerPeerRule records a freshly built peer rule in the matching
// slice, index, and dedup map. The caller must hold m.mutex.
func (m *Manager) registerPeerRule(r *PeerRule) {
if r.action == firewall.ActionDrop {
m.incomingDenyRules = append(m.incomingDenyRules, r)
m.incomingDenyIndex.add(r)
var targetMap map[netip.Addr]RuleSet
if r.drop {
targetMap = m.incomingDenyRules
} else {
m.incomingAcceptRules = append(m.incomingAcceptRules, r)
m.incomingAcceptIndex.add(r)
targetMap = m.incomingRules
}
m.peerRulesMap[r.id] = r
if _, ok := targetMap[r.ip]; !ok {
targetMap[r.ip] = make(RuleSet)
}
targetMap[r.ip][r.id] = r
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
}
// sourcesMatchAny reports whether the source list matches every source,
// i.e. contains an explicit /0 prefix. An empty list does not qualify:
// AddFilterRule rejects it with ErrNoSources, so "match any" is always
// the deliberate /0 case.
func sourcesMatchAny(sources []netip.Prefix) bool {
for _, p := range sources {
if p.Bits() == 0 {
return true
}
}
return false
}
// AddFilterRule is the unified entry point for both peer (input chain)
// and route (forward chain) filtering rules. The destination
// distinguishes the two semantics: a zero Network installs an
// input-side rule that matches by source only; a set Network installs
// a forward-side rule that also matches the destination.
func (m *Manager) AddFilterRule(
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
@@ -712,34 +553,13 @@ func (m *Manager) AddFilterRule(
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if len(sources) == 0 {
return nil, firewall.ErrNoSources
}
if destination.IsZero() {
return m.addPeerRule(id, sources, proto, sPort, dPort, action)
}
m.mutex.Lock()
defer m.mutex.Unlock()
return m.addRouteRule(id, sources, destination, proto, sPort, dPort, action)
}
// DeleteFilterRule deletes a filtering rule. The rule's underlying type
// is used to route to the correct internal path.
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if r, ok := rule.(*PeerRule); ok {
return m.deletePeerRuleLocked(r)
}
// Anything else is a route rule (matched on the forward path).
return m.deleteRouteRule(rule)
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) addRouteRule(
func (m *Manager) addRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
@@ -747,14 +567,19 @@ func (m *Manager) addRouteRule(
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
if existingRule, ok := m.routeRulesMap[ruleID]; ok {
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}
rule := RouteRule{
id: ruleID,
// TODO: consolidate these IDs
id: string(ruleKey),
mgmtId: id,
sources: sources,
dstSet: destination.Set,
@@ -769,58 +594,78 @@ func (m *Manager) addRouteRule(
m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort()
m.routeRulesMap[ruleID] = &rule
m.routeRulesMap[ruleKey] = &rule
return &rule, nil
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteRouteRule(rule)
}
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
ruleID := rule.ID()
trimmed, _, ok := removeRuleByID(m.routeRules, ruleID)
if !ok {
return fmt.Errorf("route rule not found: %s", ruleID)
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule)
}
m.routeRules = trimmed
delete(m.routeRulesMap, ruleID)
ruleKey := nbid.RuleID(rule.ID())
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == string(ruleKey)
})
if idx < 0 {
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
}
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
return nil
}
// deletePeerRuleLocked removes a peer rule from the matching slice,
// index, and dedup map. The caller must hold m.mutex.
func (m *Manager) deletePeerRuleLocked(r *PeerRule) error {
target, index := &m.incomingAcceptRules, &m.incomingAcceptIndex
if r.action == firewall.ActionDrop {
target, index = &m.incomingDenyRules, &m.incomingDenyIndex
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
r, ok := rule.(*PeerRule)
if !ok {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
trimmed, stored, ok := removeRuleByID(*target, r.id)
if !ok {
var sourceMap map[netip.Addr]RuleSet
if r.drop {
sourceMap = m.incomingDenyRules
} else {
sourceMap = m.incomingRules
}
if ruleset, ok := sourceMap[r.ip]; ok {
if _, exists := ruleset[r.id]; !exists {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(ruleset, r.id)
if len(ruleset) == 0 {
delete(sourceMap, r.ip)
}
} else {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
*target = trimmed
index.remove(stored)
delete(m.peerRulesMap, r.id)
return nil
}
// removeRuleByID removes the first rule whose id matches ruleID from
// rules, preserving order. It returns the trimmed slice, the removed
// rule, and whether a match was found.
func removeRuleByID[S ~[]T, T firewall.Rule](rules S, ruleID firewall.RuleID) (S, T, bool) {
idx := slices.IndexFunc(rules, func(r T) bool { return r.ID() == ruleID })
var removed T
if idx < 0 {
return rules, removed, false
// SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if m.nativeFirewall == nil {
return nil
}
removed = rules[idx]
return slices.Delete(rules, idx, idx+1), removed, true
}
// SetLegacyManagement is a no-op for the userspace firewall: it only matters
// when an old management server can't send route firewall rules, which the
// userspace router doesn't rely on.
func (m *Manager) SetLegacyManagement(bool) error {
return nil
return m.nativeFirewall.SetLegacyManagement(isLegacy)
}
// Flush doesn't need to be implemented for this manager
@@ -829,11 +674,9 @@ func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
m.incomingDenyRules = m.incomingDenyRules[:0]
m.incomingAcceptRules = m.incomingAcceptRules[:0]
m.incomingDenyIndex.reset()
m.incomingAcceptIndex.reset()
clear(m.peerRulesMap)
clear(m.outgoingRules)
clear(m.incomingDenyRules)
clear(m.incomingRules)
clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
@@ -865,15 +708,21 @@ func (m *Manager) resetState() {
}
}
// SetupEBPFProxyNoTrack is not supported by the userspace firewall: eBPF isn't
// used in userspace mode, so this should never be called.
func (m *Manager) SetupEBPFProxyNoTrack(uint16, uint16) error {
return errNotSupported
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.UpdateSet(set, prefixes)
}
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -971,11 +820,11 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
case layers.LayerTypeIPv4:
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
return src.Unmap(), dst.Unmap()
return src, dst
case layers.LayerTypeIPv6:
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
return src.Unmap(), dst.Unmap()
return src, dst
default:
return netip.Addr{}, netip.Addr{}
}
@@ -1555,12 +1404,20 @@ func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte)
return nil, false
}
if mgmtId, filter, ok := m.incomingDenyIndex.match(srcIP, d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := m.incomingAcceptIndex.match(srcIP, d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter
}
return nil, true
}
@@ -1581,6 +1438,39 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
return false
}
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules {
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
continue
}
if rule.protoLayer == layerTypeAll {
return rule.mgmtId, rule.drop, true
}
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
continue
}
switch payloadLayer {
case layers.LayerTypeTCP:
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeUDP:
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.mgmtId, rule.drop, true
}
}
return nil, false, false
}
// routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock()
@@ -1657,13 +1547,10 @@ func (m *Manager) EnableRouting() error {
}
rules, err := m.blockInvalidRouted(m.wgIface)
// Persist whatever was installed even on partial failure, so DisableRouting
// can clean it up later.
m.blockRules = rules
if err != nil {
// Roll back so forwarding can't stay active without the full set of
// block rules.
if derr := m.disableRouting(); derr != nil {
log.Warnf("roll back routing after block rule failure: %v", derr)
}
return fmt.Errorf("block invalid routed: %w", err)
}
@@ -1674,10 +1561,6 @@ func (m *Manager) DisableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.disableRouting()
}
func (m *Manager) disableRouting() error {
fwder := m.forwarder.Load()
if fwder == nil {
return nil

View File

@@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: false,
setupFunc: func(m *Manager) {
// Single rule allowing all traffic
_, err := m.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
require.NoError(b, err)
},
desc: "Baseline: Single 'allow all' rule without connection tracking",
@@ -114,13 +114,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0]
_, err := m.AddFilterRule(
_, err := m.AddPeerFiltering(
nil,
pfx(ip), fw.Network{},
ip,
fw.ProtocolTCP,
&fw.Port{Values: []uint16{uint16(1024 + i)}},
&fw.Port{Values: []uint16{80}},
fw.ActionAccept)
fw.ActionAccept,
"",
)
require.NoError(b, err)
}
},
@@ -131,13 +133,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: true,
setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections
_, err := m.AddFilterRule(
_, err := m.AddPeerFiltering(
nil,
pfx(net.ParseIP("0.0.0.0")), fw.Network{},
net.ParseIP("0.0.0.0"),
fw.ProtocolTCP,
nil,
nil,
fw.ActionDrop)
fw.ActionDrop,
"",
)
require.NoError(b, err)
},
desc: "Connection tracking with established connections",
@@ -164,12 +168,9 @@ func BenchmarkCoreFiltering(b *testing.B) {
}
// Create manager and basic setup
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
require.NoError(b, err)
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -207,12 +208,9 @@ func BenchmarkStateScaling(b *testing.B) {
for _, count := range connCounts {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
require.NoError(b, err)
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -253,12 +251,9 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
require.NoError(b, err)
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -414,12 +409,9 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
require.NoError(b, err)
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -544,12 +536,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
require.NoError(b, err)
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -557,7 +546,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
}
@@ -630,12 +619,9 @@ func BenchmarkShortLivedConnections(b *testing.B) {
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
require.NoError(b, err)
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -643,7 +629,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
}
@@ -744,19 +730,16 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
require.NoError(b, err)
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
}
@@ -827,18 +810,15 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
require.NoError(b, err)
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
if sc.rules {
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err)
}
@@ -951,7 +931,7 @@ func BenchmarkRouteACLs(b *testing.B) {
for _, r := range rules {
dst := fw.Network{Prefix: r.dest}
_, err := manager.AddFilterRule(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
if err != nil {
b.Fatal(err)
}
@@ -1034,11 +1014,9 @@ func BenchmarkMSSClamping(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -1101,11 +1079,9 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -1158,11 +1134,9 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))

View File

@@ -32,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
},
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NotNil(t, manager)
@@ -496,32 +496,40 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
// add general accept rule for the same IP to test drop rule precedence
rules, err := manager.AddFilterRule(
rules, err := manager.AddPeerFiltering(
nil,
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
net.ParseIP(tc.ruleIP),
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept)
fw.ActionAccept,
"",
)
require.NoError(t, err)
require.NotNil(t, rules)
require.NotEmpty(t, rules)
t.Cleanup(func() {
require.NoError(t, manager.DeleteFilterRule(rules))
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddFilterRule(
rules, err := manager.AddPeerFiltering(
nil,
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
net.ParseIP(tc.ruleIP),
tc.ruleProto,
tc.ruleSrcPort,
tc.ruleDstPort,
tc.ruleAction)
tc.ruleAction,
"",
)
require.NoError(t, err)
require.NotNil(t, rules)
require.NotEmpty(t, rules)
t.Cleanup(func() {
require.NoError(t, manager.DeleteFilterRule(rules))
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
@@ -549,7 +557,7 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
},
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
@@ -664,18 +672,22 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.DeleteFilterRule(rules))
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction)
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
require.NoError(t, err)
require.NotNil(t, rules)
require.NotEmpty(t, rules)
t.Cleanup(func() {
require.NoError(t, manager.DeleteFilterRule(rules))
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
@@ -788,7 +800,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
},
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting())
require.NotNil(tb, manager)
@@ -1393,7 +1405,7 @@ func TestRouteACLFiltering(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
if tc.rule.action == fw.ActionDrop {
// add general accept rule to test drop rule
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
@@ -1403,13 +1415,13 @@ func TestRouteACLFiltering(t *testing.T) {
fw.ActionAccept,
)
require.NoError(t, err)
require.NotEmpty(t, rule)
require.NotNil(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteFilterRule(rule))
require.NoError(t, manager.DeleteRouteRule(rule))
})
}
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
nil,
tc.rule.sources,
tc.rule.dest,
@@ -1419,10 +1431,10 @@ func TestRouteACLFiltering(t *testing.T) {
tc.rule.action,
)
require.NoError(t, err)
require.NotEmpty(t, rule)
require.NotNil(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteFilterRule(rule))
require.NoError(t, manager.DeleteRouteRule(rule))
})
srcIP := netip.MustParseAddr(tc.srcIP)
@@ -1590,9 +1602,9 @@ func TestRouteACLOrder(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var addedRules []fw.Rule
var rules []fw.Rule
for _, r := range tc.rules {
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
nil,
r.sources,
r.dest,
@@ -1603,12 +1615,12 @@ func TestRouteACLOrder(t *testing.T) {
)
require.NoError(t, err)
require.NotNil(t, rule)
addedRules = append(addedRules, rule)
rules = append(rules, rule)
}
t.Cleanup(func() {
for _, rule := range addedRules {
require.NoError(t, manager.DeleteFilterRule(rule))
for _, rule := range rules {
require.NoError(t, manager.DeleteRouteRule(rule))
}
})
@@ -1634,7 +1646,7 @@ func TestRouteACLSet(t *testing.T) {
},
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -1643,7 +1655,7 @@ func TestRouteACLSet(t *testing.T) {
set := fw.NewDomainSet(domain.List{"example.org"})
// Add rule that uses the set (initially empty)
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
@@ -1677,7 +1689,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
manager := setupRoutedManager(t, "10.10.0.100/16")
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
_, err := manager.AddFilterRule(
_, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: v6Dst},
@@ -1688,7 +1700,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
)
require.NoError(t, err)
_, err = manager.AddFilterRule(
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},

View File

@@ -29,7 +29,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add rule first time
rule1, err := manager.AddFilterRule(
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -42,7 +42,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
require.NotNil(t, rule1)
// Add the same rule again
rule2, err := manager.AddFilterRule(
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -74,7 +74,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
// Add first rule
rule1, err := manager.AddFilterRule(
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
@@ -86,7 +86,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
require.NoError(t, err)
// Add different rule (different destination)
rule2, err := manager.AddFilterRule(
rule2, err := manager.AddRouteFiltering(
[]byte("policy-2"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
@@ -115,7 +115,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
rule1, err := manager.AddFilterRule(
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -132,7 +132,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
require.True(t, pass, "Traffic should pass with rule in place")
// Re-add same rule (simulates network map update)
rule2, err := manager.AddFilterRule(
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -147,7 +147,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
// would remove the only matching rule and cause a traffic gap.
if rule1.ID() != rule2.ID() {
err = manager.DeleteFilterRule(rule1)
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
}
@@ -156,59 +156,6 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
"Traffic should still pass after rule update - no gap should occur")
}
// TestBlockInvalidRoutedDualStack verifies that when the interface has an
// IPv6 overlay address, blockInvalidRouted installs a block rule for both
// the v4 and v6 WG prefixes and that routed traffic to the v6 prefix is
// denied. The v4-only soft-skip path is covered by
// TestBlockInvalidRoutedIdempotent, whose mock leaves IPv6Net invalid.
func TestBlockInvalidRoutedDualStack(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
wgNet6 := netip.MustParsePrefix("fd00:1234::1/64")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
IPv6: wgNet6.Addr(),
IPv6Net: wgNet6,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
rules, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.Len(t, rules, 2, "dual-stack interface must produce a v4 and a v6 block rule")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount, "should have one block rule per family")
// v6 routed traffic to the WG prefix must be denied.
srcIP := netip.MustParseAddr("2001:db8::1")
dstIP := netip.MustParseAddr("fd00:1234::50")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
assert.False(t, pass, "block rule should deny routed traffic to the v6 WG prefix")
}
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
// exactly one drop rule for the WireGuard network prefix, and calling it again
// returns the same rule without duplicating.
@@ -235,7 +182,7 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) {
},
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -298,7 +245,7 @@ func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
},
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -327,7 +274,7 @@ func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
// Simulate 5 network map updates with the same route rule
for i := 0; i < 5; i++ {
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -357,7 +304,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add same rule twice
rule1, err := manager.AddFilterRule(
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -368,7 +315,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
)
require.NoError(t, err)
rule2, err := manager.AddFilterRule(
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
@@ -382,7 +329,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
// Delete using first reference
err = manager.DeleteFilterRule(rule1)
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
// Verify traffic no longer passes
@@ -417,7 +364,7 @@ func setupTestManager(t *testing.T) *Manager {
},
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.EnableRouting())

View File

@@ -78,19 +78,18 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
}
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
if m == nil {
t.Error("Manager is nil")
}
}
func TestManagerAddFilterRule(t *testing.T) {
func TestManagerAddPeerFiltering(t *testing.T) {
isSetFilterCalled := false
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error {
@@ -99,19 +98,18 @@ func TestManagerAddFilterRule(t *testing.T) {
},
}
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
}
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
rule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -133,47 +131,74 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
}
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
ip := netip.MustParseAddr("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
rule2, err := m.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, proto, nil, port, action)
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
peerRule, ok := rule2.(*PeerRule)
require.True(t, ok, "rule should be a PeerRule")
inMap := func() bool {
if peerRule.action == fw.ActionDrop {
return findRuleByID(m.incomingDenyRules, ip, rule2.ID())
// Check rules exist in appropriate maps
for _, r := range rule2 {
peerRule, ok := r.(*PeerRule)
if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule exists in deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
} else {
_, found = m.incomingRules[ip][r.ID()]
}
if !found {
t.Errorf("rule2 is not in the expected rules map")
}
return findRuleByID(m.incomingAcceptRules, ip, rule2.ID())
}
require.True(t, inMap(), "rule2 should be in the expected rules list")
for _, r := range rule2 {
err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
}
require.NoError(t, m.DeleteFilterRule(rule2), "failed to delete rule")
require.False(t, inMap(), "rule2 should be removed from the rules list")
// Check rules are removed from appropriate maps
for _, r := range rule2 {
peerRule, ok := r.(*PeerRule)
if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule is removed from deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
} else {
_, found = m.incomingRules[ip][r.ID()]
}
if found {
t.Errorf("rule2 should be removed from the rules map")
}
}
}
func TestSetUDPPacketHook(t *testing.T) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
@@ -195,11 +220,9 @@ func TestSetUDPPacketHook(t *testing.T) {
}
func TestSetTCPPacketHook(t *testing.T) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
@@ -227,7 +250,7 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
@@ -237,34 +260,36 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
addr := netip.MustParseAddr("192.168.1.1")
// Add multiple deny rules for different ports
rule1, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
rule2, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
require.NoError(t, err)
m.mutex.RLock()
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
// Delete the first deny rule
err = m.DeleteFilterRule(rule1)
err = m.DeletePeerRule(rule1[0])
require.NoError(t, err)
m.mutex.RLock()
denyCount = countRulesForAddr(m.incomingDenyRules, addr)
denyCount = len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
// Delete the second deny rule
err = m.DeleteFilterRule(rule2)
err = m.DeletePeerRule(rule2[0])
require.NoError(t, err)
m.mutex.RLock()
exists := countRulesForAddr(m.incomingDenyRules, addr) > 0
_, exists := m.incomingDenyRules[addr]
m.mutex.RUnlock()
require.False(t, exists, "Deny rules should be cleaned up when empty")
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
}
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
@@ -274,7 +299,7 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
@@ -286,21 +311,27 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
// Simulate 10 network map updates: add rule, delete old, add new
for i := 0; i < 10; i++ {
// Add a deny rule
rules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
// Add an allow rule
allowRules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Delete them (simulating ACL manager cleanup)
require.NoError(t, m.DeleteFilterRule(rules))
require.NoError(t, m.DeleteFilterRule(allowRules))
for _, r := range rules {
require.NoError(t, m.DeletePeerRule(r))
}
for _, r := range allowRules {
require.NoError(t, m.DeletePeerRule(r))
}
}
m.mutex.RLock()
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
denyCount := len(m.incomingDenyRules[addr])
allowCount := len(m.incomingRules[addr])
m.mutex.RUnlock()
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
@@ -314,7 +345,7 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
@@ -323,39 +354,41 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
ip := net.ParseIP("192.168.1.1")
// Add allow rule for port 80
allowRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Add deny rule for port 22
denyRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
addr := netip.MustParseAddr("192.168.1.1")
m.mutex.RLock()
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
allowCount := len(m.incomingRules[addr])
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
// Delete allow rule should not affect deny rule
err = m.DeleteFilterRule(allowRule)
err = m.DeletePeerRule(allowRule[0])
require.NoError(t, err)
m.mutex.RLock()
denyCountAfter := countRulesForAddr(m.incomingDenyRules, addr)
denyCountAfter := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
// Delete deny rule
err = m.DeleteFilterRule(denyRule)
err = m.DeletePeerRule(denyRule[0])
require.NoError(t, err)
m.mutex.RLock()
denyExists := countRulesForAddr(m.incomingDenyRules, addr) > 0
allowExists := countRulesForAddr(m.incomingAcceptRules, addr) > 0
_, denyExists := m.incomingDenyRules[addr]
_, allowExists := m.incomingRules[addr]
m.mutex.RUnlock()
require.False(t, denyExists, "Deny rules should be empty")
@@ -367,7 +400,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -378,7 +411,7 @@ func TestManagerReset(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -390,7 +423,7 @@ func TestManagerReset(t *testing.T) {
return
}
if len(m.incomingAcceptRules) != 0 || len(m.incomingDenyRules) != 0 {
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
t.Errorf("rules are not empty")
}
}
@@ -406,7 +439,7 @@ func TestNotMatchByIP(t *testing.T) {
},
}
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -416,7 +449,7 @@ func TestNotMatchByIP(t *testing.T) {
proto := fw.ProtocolUDP
action := fw.ActionAccept
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -469,7 +502,7 @@ func TestRemovePacketHook(t *testing.T) {
}
// creating manager instance
manager, err := Create(Config{IFace: iface, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
@@ -486,11 +519,9 @@ func TestRemovePacketHook(t *testing.T) {
}
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.udpTracker.Close()
@@ -575,7 +606,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -590,7 +621,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddFilterRule(nil, pfx(ip), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
}
@@ -600,11 +631,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
}
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.udpTracker.Close() // Close the existing tracker
@@ -816,7 +845,7 @@ func TestUpdateSetMerge(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -829,7 +858,7 @@ func TestUpdateSetMerge(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"),
}
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
@@ -902,7 +931,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -910,7 +939,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
set := fw.NewDomainSet(domain.List{"example.org"})
rule, err := manager.AddFilterRule(
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
@@ -1022,7 +1051,7 @@ func TestMSSClamping(t *testing.T) {
},
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: 1280})
manager, err := Create(ifaceMock, false, flowLogger, 1280)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -1214,7 +1243,7 @@ func TestShouldForward(t *testing.T) {
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
}
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -1329,7 +1358,7 @@ func TestShouldForward(t *testing.T) {
// Re-create manager to pick up the new address with IPv6
require.NoError(t, manager.Close(nil))
manager, err = Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
v6Cases := []struct {

View File

@@ -10,7 +10,6 @@ import (
"time"
log "github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -21,9 +20,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -34,12 +33,6 @@ const (
iosMaxInFlight = 256
)
// IFace provides the WireGuard device and overlay addresses the forwarder needs.
type IFace interface {
GetWGDevice() *wgdevice.Device
Address() wgaddr.Address
}
type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
@@ -58,7 +51,7 @@ type Forwarder struct {
pingSemaphore chan struct{}
}
func New(iface IFace, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,

View File

@@ -362,10 +362,6 @@ func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload
return 0
}
if pc := f.endpoint.capture.Load(); pc != nil {
(*pc).Offer(fullPacket, true)
}
return len(fullPacket)
}

View File

@@ -7,6 +7,8 @@ import (
"sync/atomic"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
)
// localIPSnapshot is an immutable snapshot of local IP addresses, swapped
@@ -58,7 +60,7 @@ func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresse
}
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
func (m *localIPManager) UpdateLocalIPs(iface Iface) (err error) {
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %v", r)

View File

@@ -487,13 +487,19 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
}
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
func (m *Manager) AddDNATRule(firewall.ForwardRule) (firewall.Rule, error) {
return nil, errNotSupported
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes outbound DNAT rule.
func (m *Manager) DeleteDNATRule(firewall.Rule) error {
return errNotSupported
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// addPortRedirection adds a port redirection rule.
@@ -515,6 +521,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
// TODO: also delegate to nativeFirewall when available for kernel WG mode
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
@@ -560,16 +567,20 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.removePortRedirection(localAddr, layerType, originalPort, translatedPort)
}
// AddOutputDNAT is not supported by the userspace firewall: it backs kernel DNS
// redirection, but userspace DNS is served in-process on the gVisor netstack, so
// this should never be called.
func (m *Manager) AddOutputDNAT(netip.Addr, firewall.Protocol, uint16, uint16) error {
return errNotSupported
// AddOutputDNAT delegates to the native firewall if available.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if m.nativeFirewall == nil {
return fmt.Errorf("output DNAT not supported without native firewall")
}
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveOutputDNAT is a no-op for the userspace firewall (see AddOutputDNAT).
func (m *Manager) RemoveOutputDNAT(netip.Addr, firewall.Protocol, uint16, uint16) error {
return nil
// RemoveOutputDNAT delegates to the native firewall if available.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.

View File

@@ -64,11 +64,9 @@ func BenchmarkDNATTranslation(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -126,11 +124,9 @@ func BenchmarkDNATTranslation(b *testing.B) {
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -200,11 +196,9 @@ func BenchmarkDNATScaling(b *testing.B) {
for _, count := range mappingCounts {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -314,11 +308,9 @@ func BenchmarkChecksumUpdate(b *testing.B) {
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -489,11 +481,9 @@ func BenchmarkPortDNAT(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))

View File

@@ -13,11 +13,9 @@ import (
// TestPortDNATBasic tests basic port DNAT functionality
func TestPortDNATBasic(t *testing.T) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -51,11 +49,9 @@ func TestPortDNATBasic(t *testing.T) {
// TestPortDNATMultipleRules tests multiple port DNAT rules
func TestPortDNATMultipleRules(t *testing.T) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))

View File

@@ -15,11 +15,9 @@ import (
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -106,11 +104,9 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
// TestDNATMappingManagement tests adding/removing DNAT mappings
func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -156,11 +152,9 @@ func TestDNATMappingManagement(t *testing.T) {
}
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -208,11 +202,9 @@ func TestInboundPortDNAT(t *testing.T) {
}
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))

View File

@@ -1,333 +0,0 @@
//go:build uspbench
package uspfilter
import (
"fmt"
"io"
"math/rand"
"net"
"net/netip"
"runtime"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// BenchmarkPeerACLMatch measures the per-packet cost of the peer ACL
// matcher (peerACLsBlock) across realistic shapes: M distinct policy
// rules, each with K source peers in its set.
//
// With the reverse-source index, miss cost is independent of M and
// hit cost grows only with the number of rules touching a single
// srcIP, not with total rule count.
func BenchmarkPeerACLMatch(b *testing.B) {
shapes := []struct{ M, K int }{
{1, 100}, {10, 100}, {50, 100}, {100, 100}, {100, 1000},
}
families := []struct {
name string
v6 bool
}{{"v4", false}, {"v6", true}}
for _, fam := range families {
for _, s := range shapes {
b.Run(fmt.Sprintf("%s/M=%d/K=%d/hit", fam.name, s.M, s.K), func(b *testing.B) {
runPeerACLBench(b, s.M, s.K, true, fam.v6)
})
b.Run(fmt.Sprintf("%s/M=%d/K=%d/miss", fam.name, s.M, s.K), func(b *testing.B) {
runPeerACLBench(b, s.M, s.K, false, fam.v6)
})
}
}
}
func runPeerACLBench(b *testing.B, m, k int, hit, v6 bool) {
log.SetOutput(io.Discard) // keep manager logs out of the benchmark output
// Miss packets are dropped, so they always traverse the full peer
// ACL matcher (every bucket) without short-circuiting and without
// touching conntrack. Disable conntrack for the miss case so it
// measures the matcher, not established-state lookups. The hit case
// keeps conntrack on: an accepted packet reaches trackInbound, which
// needs the trackers conntrack creates.
if !hit {
b.Setenv("NB_DISABLE_CONNTRACK", "1")
}
bits := 32
genPkt := generatePacket
addrs := uniqueAddrs
if v6 {
bits = 128
genPkt = generatePacket6
addrs = uniqueAddrs6
}
// dstIP must be a local IP so filterInbound takes the local-traffic
// path (handleLocalTraffic → peerACLsBlock) we are measuring; an
// address the manager doesn't own would be treated as routed and
// short-circuit before the peer matcher.
dstIP := addrs(1, 2)[0]
mockAddr := wgaddr.Address{IP: dstIP, Network: netip.PrefixFrom(dstIP, bits)}
if v6 {
// The local-IP manager needs a valid v4 address too; expose the v6
// dst as the interface's IPv6 so IsLocalIP recognizes it.
mockAddr = wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: dstIP,
IPv6Net: netip.PrefixFrom(dstIP, bits),
}
}
manager, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { return mockAddr },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
require.NoError(b, err)
b.Cleanup(func() { require.NoError(b, manager.Close(nil)) })
// Generate M policies × K source peers, all distinct.
all := addrs(m*k, 1)
for i := 0; i < m; i++ {
sources := make([]netip.Prefix, k)
for j, a := range all[i*k : (i+1)*k] {
sources[j] = netip.PrefixFrom(a, bits)
}
_, err := manager.AddFilterRule(
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{uint16(80 + i)}},
fw.ActionAccept)
require.NoError(b, err)
}
// Hit: cycle through real sources, picking the matching policy's port.
// Miss: a source from a disjoint range, port 80 (matches no policy).
var pktFn func(i int) []byte
if hit {
pktFn = func(i int) []byte {
policy := i % m
src := all[policy*k+(i%k)]
return genPkt(b, src.AsSlice(), dstIP.AsSlice(),
uint16(1024+i%60000), uint16(80+policy), layers.IPProtocolTCP)
}
} else {
miss := addrs(4096, 99)
pktFn = func(i int) []byte {
return genPkt(b, miss[i%len(miss)].AsSlice(), dstIP.AsSlice(),
uint16(1024+i%60000), 80, layers.IPProtocolTCP)
}
}
// Pre-build a pool to avoid allocations dominating the measurement.
pool := make([][]byte, 1024)
for i := range pool {
pool[i] = pktFn(i)
}
// Confirm the matcher is actually exercised: a hit packet must be
// allowed and a miss packet dropped. Without this the benchmark
// could silently time the routed early-return instead.
require.Equal(b, !hit, manager.filterInbound(pool[0], 0),
"benchmark must reach the peer ACL matcher")
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(pool[i%len(pool)], 0)
}
}
// BenchmarkPeerACLIndexMemory reports the resident memory cost of
// the source-keyed index across realistic deployment shapes. Two
// dimensions matter: (M, K), the number of policies × peers-per-policy,
// and overlap, the fraction of peers shared between policies.
//
// The output uses ReportMetric("bytes/rule") so the cost can be
// compared across shapes directly. Total bytes = bytes/rule * M.
func BenchmarkPeerACLIndexMemory(b *testing.B) {
cases := []struct {
name string
M, K int
overlapFrac float64 // 0 = disjoint per-policy sources, 1 = all share the same pool
}{
{"M=10/K=100/disjoint", 10, 100, 0},
{"M=100/K=100/disjoint", 100, 100, 0},
{"M=100/K=1000/disjoint", 100, 1000, 0},
{"M=100/K=1000/overlap=0.5", 100, 1000, 0.5},
{"M=100/K=1000/overlap=1.0", 100, 1000, 1.0},
{"M=1000/K=100/overlap=1.0", 1000, 100, 1.0},
}
for _, c := range cases {
b.Run(c.name, func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mgr, err := Create(Config{
IFace: &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
},
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
require.NoError(b, err)
populateIndexedRules(b, mgr, c.M, c.K, c.overlapFrac)
runtime.GC()
var ms runtime.MemStats
runtime.ReadMemStats(&ms)
before := ms.HeapAlloc
// Drop the manager's external roots so we can isolate
// the index cost. We hold the manager itself live; the
// index is what we measure on the second pass.
mgr.incomingAcceptIndex.reset()
mgr.incomingDenyIndex.reset()
mgr.incomingAcceptRules = mgr.incomingAcceptRules[:0]
mgr.incomingDenyRules = mgr.incomingDenyRules[:0]
runtime.GC()
runtime.ReadMemStats(&ms)
after := ms.HeapAlloc
delta := int64(before) - int64(after)
if delta < 0 {
delta = 0
}
b.ReportMetric(float64(delta)/float64(c.M), "bytes/rule")
b.ReportMetric(float64(delta), "bytes/total")
require.NoError(b, mgr.Close(nil))
}
})
}
}
func populateIndexedRules(b *testing.B, mgr *Manager, m, k int, overlapFrac float64) {
b.Helper()
pool := uniqueAddrs(k+m*k, 1) // big enough universe
sharedLen := int(float64(k) * overlapFrac)
if sharedLen > k {
sharedLen = k
}
shared := pool[:sharedLen]
uniquePool := pool[sharedLen:]
for i := 0; i < m; i++ {
sources := make([]netip.Prefix, 0, k)
for _, a := range shared {
sources = append(sources, netip.PrefixFrom(a, 32))
}
// each policy gets (k-sharedLen) addresses unique to it from the unique pool
unique := uniquePool[i*(k-sharedLen) : (i+1)*(k-sharedLen)]
for _, a := range unique {
sources = append(sources, netip.PrefixFrom(a, 32))
}
_, err := mgr.AddFilterRule(
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{uint16(80 + i)}},
fw.ActionAccept)
require.NoError(b, err)
}
}
// uniqueAddrs returns n distinct addrs. Seeds 1, 2 are used for
// policy sources / dst; seed 99 puts misses in 10/8.
func uniqueAddrs(n int, seed int64) []netip.Addr {
out := make([]netip.Addr, 0, n)
seen := make(map[netip.Addr]struct{}, n)
r := rand.New(rand.NewSource(seed))
miss := seed == 99
for len(out) < n {
var b [4]byte
if miss {
b[0] = 10
b[1] = byte(r.Intn(256))
} else {
b[0] = 100
b[1] = byte(64 + r.Intn(63))
}
b[2] = byte(r.Intn(256))
b[3] = byte(1 + r.Intn(254))
a := netip.AddrFrom4(b)
if _, ok := seen[a]; ok {
continue
}
seen[a] = struct{}{}
out = append(out, a)
}
return out
}
// uniqueAddrs6 mirrors uniqueAddrs for IPv6: sources come from the ULA
// range fd00::/8, the miss set (seed 99) from 2001:db8::/32 so it is
// disjoint from any source.
func uniqueAddrs6(n int, seed int64) []netip.Addr {
out := make([]netip.Addr, 0, n)
seen := make(map[netip.Addr]struct{}, n)
r := rand.New(rand.NewSource(seed))
miss := seed == 99
for len(out) < n {
var b [16]byte
if miss {
b[0], b[1], b[2], b[3] = 0x20, 0x01, 0x0d, 0xb8
} else {
b[0] = 0xfd
}
for x := 8; x < 16; x++ {
b[x] = byte(r.Intn(256))
}
a := netip.AddrFrom16(b)
if _, ok := seen[a]; ok {
continue
}
seen[a] = struct{}{}
out = append(out, a)
}
return out
}
// generatePacket6 builds an IPv6 TCP/UDP packet, mirroring
// generatePacket for the v4 case.
func generatePacket6(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
b.Helper()
ipv6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: protocol,
SrcIP: srcIP,
DstIP: dstIP,
}
var transportLayer gopacket.SerializableLayer
switch protocol {
case layers.IPProtocolTCP:
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
}
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv6))
transportLayer = tcp
case layers.IPProtocolUDP:
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv6))
transportLayer = udp
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv6, transportLayer, gopacket.Payload("test")))
return buf.Bytes()
}

View File

@@ -1,150 +0,0 @@
package uspfilter
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
func newTestManager(t *testing.T) *Manager {
t.Helper()
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
require.NoError(t, err, "create manager")
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
return m
}
// TestAddPeerFiltering_DeduplicatesIdenticalRules verifies that adding
// the same peer rule twice does not create two backing rules. The acl
// manager keys its own cache, but the firewall backend must be
// idempotent on its own so a double-apply cannot leak rules, matching
// the route path and the kernel backends.
func TestAddPeerFiltering_DeduplicatesIdenticalRules(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "first add")
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "second add")
assert.Equal(t, first.ID(), second.ID(), "duplicate add should return the same rule id")
assert.Len(t, m.incomingDenyRules, 1, "duplicate add must not create a second backing rule")
}
// TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves locks the
// backend's no-refcount contract: a content key installed twice is
// still one rule, and the first DeleteFilterRule removes it. The
// backend does not refcount, so balance is the caller's job (it keys
// its tracking by the returned id and deletes once per key). If this
// ever silently grew a refcount, the acl manager's delete accounting
// would diverge from the kernel.
func TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "first add")
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "second add")
require.Equal(t, first.ID(), second.ID(), "dedup to one rule")
require.Len(t, m.incomingDenyRules, 1, "still one backing rule after duplicate add")
require.NoError(t, m.DeleteFilterRule(first), "delete once")
assert.Empty(t, m.incomingDenyRules, "single delete removes the backing rule (no refcount)")
assert.NotContains(t, m.peerRulesMap, first.ID(), "dedup map entry cleared")
}
// TestAddPeerFiltering_DeterministicID verifies the peer rule id is a
// content hash, not a random UUID: identical inputs produce the same id
// across independent managers. A random id breaks caller-side dedup.
func TestAddPeerFiltering_DeterministicID(t *testing.T) {
ip := net.ParseIP("10.0.0.5")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
m1 := newTestManager(t)
r1, err := m1.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
m2 := newTestManager(t)
r2, err := m2.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
assert.Equal(t, r1.ID(), r2.ID(), "same inputs must produce the same rule id")
}
// TestAddPeerFiltering_DistinctRulesNotDeduped verifies that rules
// differing only by port are kept separate.
func TestAddPeerFiltering_DistinctRulesNotDeduped(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
action := fw.ActionAccept
r80, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{80}}, action)
require.NoError(t, err)
r443, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{443}}, action)
require.NoError(t, err)
assert.NotEqual(t, r80.ID(), r443.ID(), "different ports must produce different rule ids")
assert.Len(t, m.incomingAcceptRules, 2, "distinct rules must both be stored")
}
// TestAddPeerFiltering_SourceVsDestPortNotDeduped verifies that a rule
// matching on source port and one matching on destination port for the
// same selector do not collide: the port lands in a different slot, so
// the content key must differ.
func TestAddPeerFiltering_SourceVsDestPortNotDeduped(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
dPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
sPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, port, nil, action)
require.NoError(t, err)
assert.NotEqual(t, dPortRule.ID(), sPortRule.ID(), "source-port and dest-port matches must produce different rule ids")
}
// TestAddFilterRule_EmptySourcesRejected verifies that an empty source
// list is rejected rather than treated as "match any". "Match any" must
// be an explicit /0, so a zeroed list can never silently widen a rule to
// every source.
func TestAddFilterRule_EmptySourcesRejected(t *testing.T) {
m := newTestManager(t)
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
_, err := m.AddFilterRule(nil, nil, fw.Network{}, proto, nil, port, fw.ActionAccept)
require.ErrorIs(t, err, fw.ErrNoSources, "empty sources must be rejected")
assert.Empty(t, m.incomingAcceptRules, "no rule should be stored for empty sources")
}

View File

@@ -1,105 +0,0 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func newV6TestManager(t *testing.T, localV6 string) *Manager {
t.Helper()
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
IPv6: netip.MustParseAddr(localV6),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
require.NoError(t, err, "create manager")
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
return m
}
func v6UDPPacket(t *testing.T, src, dst string, dstPort uint16) []byte {
t.Helper()
ip6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
SrcIP: net.ParseIP(src),
DstIP: net.ParseIP(dst),
}
udp := &layers.UDP{SrcPort: 51334, DstPort: layers.UDPPort(dstPort)}
require.NoError(t, udp.SetNetworkLayerForChecksum(ip6))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(t, gopacket.SerializeLayers(buf, opts, ip6, udp, gopacket.Payload("test")))
return buf.Bytes()
}
// TestPeerACL_IPv6HostRule verifies the source index resolves /128 v6
// rules: a matching v6 source is accepted, a non-matching one is
// denied by the default. This is the end-to-end proof that the index
// is not v4-only.
func TestPeerACL_IPv6HostRule(t *testing.T) {
m := newV6TestManager(t, "fd00::100")
src := net.ParseIP("fd00::1")
_, err := m.AddFilterRule(nil, pfx(src), fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
require.NoError(t, err, "add v6 accept rule")
require.False(t, m.filterInbound(v6UDPPacket(t, "fd00::1", "fd00::100", 53), 0),
"v6 packet from the allowed /128 source must be accepted")
require.True(t, m.filterInbound(v6UDPPacket(t, "fd00::2", "fd00::100", 53), 0),
"v6 packet from an unlisted source must be denied by default")
}
// TestPeerACL_IPv6IndexBuckets verifies that v6 sources land in the
// right index bucket: a /128 in bySource keyed by its address, and
// coarser prefixes (including ::/0) in the nonHost slice.
func TestPeerACL_IPv6IndexBuckets(t *testing.T) {
m := newV6TestManager(t, "fd00::100")
port := &fw.Port{Values: []uint16{53}}
host := netip.MustParseAddr("fd00::1")
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(host, 128)}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
require.NoError(t, err)
assert.Contains(t, m.incomingAcceptIndex.bySource, host, "/128 v6 source must be indexed by address")
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("fd00:dead::/64")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
require.NoError(t, err)
require.Len(t, m.incomingAcceptIndex.nonHost, 1, "coarser v6 prefix must land in nonHost")
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("::/0")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
require.NoError(t, err)
require.Len(t, m.incomingAcceptIndex.nonHost, 2, "::/0 source must also land in nonHost")
}
// TestPeerACL_IPv4MappedSourceNormalized verifies a v4-mapped v6
// source prefix is normalized to v4 so a plain v4 packet matches it.
func TestPeerACL_IPv4MappedSourceNormalized(t *testing.T) {
m := newTestManager(t)
mapped := netip.MustParseAddr("::ffff:192.168.1.1")
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(mapped, mapped.BitLen())}, fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
require.NoError(t, err)
v4 := netip.MustParseAddr("192.168.1.1")
assert.Contains(t, m.incomingAcceptIndex.bySource, v4, "v4-mapped v6 source must be indexed as plain v4")
}

View File

@@ -1,139 +0,0 @@
package uspfilter
import (
"net/netip"
"slices"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
// peerRuleIndex is the source-side dispatcher consulted on the packet
// hot path. It splits rules into two buckets by the shape of their
// source list:
//
// - bySource: every source is a host prefix (/32 for v4, /128 for
// v6). Keyed by the concrete source address, so a hit guarantees
// the source filter passes and the matcher goes straight to
// proto/port checks. This is the common case for peer ACLs.
// - nonHost: any source list with a prefix coarser than a host,
// including a /0 "match any". Walked linearly with a per-rule
// Contains() check. Expected small or empty for typical peer ACLs.
//
// Maintained incrementally by add/remove, never rebuilt.
type peerRuleIndex struct {
bySource map[netip.Addr][]*PeerRule
nonHost []*PeerRule
}
func (i *peerRuleIndex) add(r *PeerRule) {
if hasNonHostSource(r) {
i.nonHost = append(i.nonHost, r)
return
}
if i.bySource == nil {
i.bySource = make(map[netip.Addr][]*PeerRule)
}
for a := range r.sourceAddrs {
i.bySource[a] = append(i.bySource[a], r)
}
}
func (i *peerRuleIndex) remove(r *PeerRule) {
if hasNonHostSource(r) {
i.nonHost = slices.DeleteFunc(i.nonHost, eqRule(r))
return
}
if i.bySource == nil {
return
}
for a := range r.sourceAddrs {
entries := slices.DeleteFunc(i.bySource[a], eqRule(r))
if len(entries) == 0 {
delete(i.bySource, a)
} else {
i.bySource[a] = entries
}
}
}
func (i *peerRuleIndex) reset() {
i.bySource = nil
i.nonHost = i.nonHost[:0]
}
// match returns the first rule matching src and the decoded packet.
// Host rules are found by direct map lookup; nonHost rules need a
// per-rule source Contains() check, except match-any (/0) rules which
// apply to every source regardless of family (a v4 /0 also matches v6).
// Within either bucket the matcher runs the proto/port filter.
func (i *peerRuleIndex) match(src netip.Addr, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range i.bySource[src] {
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
return id, drop, true
}
}
for _, rule := range i.nonHost {
if !rule.matchAny && !prefixesContain(rule.sources, src) {
continue
}
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
return id, drop, true
}
}
return nil, false, false
}
func eqRule(target *PeerRule) func(*PeerRule) bool {
return func(p *PeerRule) bool { return p == target }
}
// hasNonHostSource reports whether the rule has any source prefix
// that is not a single host address. Called only at add/remove time,
// not on the packet path.
func hasNonHostSource(r *PeerRule) bool {
for _, p := range r.sources {
if p.Bits() != p.Addr().BitLen() {
return true
}
}
return false
}
// matchProto applies the proto/port half of a rule against the
// decoded packet. Source matching is the caller's responsibility.
func matchProto(rule *PeerRule, d *decoder, payloadLayer gopacket.LayerType) ([]byte, bool, bool) {
drop := rule.action == firewall.ActionDrop
if rule.protoLayer == layerTypeAll {
return rule.mgmtId, drop, true
}
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
return nil, false, false
}
switch payloadLayer {
case layers.LayerTypeTCP:
if portsMatch(rule.srcPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.tcp.DstPort)) {
return rule.mgmtId, drop, true
}
case layers.LayerTypeUDP:
if portsMatch(rule.srcPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, drop, true
}
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.mgmtId, drop, true
}
return nil, false, false
}
func prefixesContain(sources []netip.Prefix, src netip.Addr) bool {
for _, p := range sources {
if p.Contains(src) {
return true
}
}
return false
}

View File

@@ -10,49 +10,24 @@ import (
// PeerRule to handle management of rules
type PeerRule struct {
id firewall.RuleID
mgmtId []byte
// sources is the canonical list of source prefixes this rule
// matches against.
sources []netip.Prefix
// sourceAddrs is a fast-path membership set for host-prefix
// sources (/32 v4, /128 v6). Populated alongside sources;
// consulted before falling back to prefix scan.
sourceAddrs map[netip.Addr]struct{}
// matchAny is true when sources covers everything (0.0.0.0/0,
// ::/0). In that case neither sourceAddrs nor sources need to be
// consulted.
matchAny bool
id string
mgmtId []byte
ip netip.Addr
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
}
// matchesSource reports whether the given source address is covered
// by this rule's source list.
func (r *PeerRule) matchesSource(src netip.Addr) bool {
if r.matchAny {
return true
}
if _, ok := r.sourceAddrs[src]; ok {
return true
}
for _, p := range r.sources {
if p.Contains(src) {
return true
}
}
return false
sPort *firewall.Port
dPort *firewall.Port
drop bool
}
// ID returns the rule id
func (r *PeerRule) ID() firewall.RuleID {
func (r *PeerRule) ID() string {
return r.id
}
type RouteRule struct {
id firewall.RuleID
id string
mgmtId []byte
sources []netip.Prefix
dstSet firewall.Set
@@ -64,6 +39,6 @@ type RouteRule struct {
}
// ID returns the rule id
func (r *RouteRule) ID() firewall.RuleID {
func (r *RouteRule) ID() string {
return r.id
}

View File

@@ -1,50 +0,0 @@
package uspfilter
import (
"net"
"net/netip"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
// countRulesForAddr reports how many rules in the given slice match
// the supplied source address.
func countRulesForAddr(rules peerRules, src netip.Addr) int {
n := 0
for _, r := range rules {
if r.matchesSource(src) {
n++
}
}
return n
}
// findRuleByID returns true if the rules slice contains a rule with
// the given id whose source set covers src.
func findRuleByID(rules peerRules, src netip.Addr, id firewall.RuleID) bool {
for _, r := range rules {
if r.id == id && r.matchesSource(src) {
return true
}
}
return false
}
// pfx converts a single net.IP into the []netip.Prefix form
// AddFilterRule expects. A nil or unspecified address becomes a /0
// ("match any") prefix in the matching family; any other address
// becomes its /32 (or /128) host prefix.
func pfx(ip net.IP) []netip.Prefix {
if ip == nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if ip.IsUnspecified() {
if ip.To4() != nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
a, _ := netip.AddrFromSlice(ip)
a = a.Unmap()
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
}

View File

@@ -285,14 +285,6 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
trace.SourceIP = srcIP
trace.DestinationIP = dstIP
// A fragment or otherwise truncated packet has no transport layer.
// The inbound datapath drops these via isValidPacket; the tracer must
// guard explicitly since every downstream stage reads d.decoded[1].
if len(d.decoded) < 2 {
trace.AddResult(StageReceived, "Packet has no transport layer (fragment or unsupported protocol)", false)
return trace
}
// Determine protocol and ports
switch d.decoded[1] {
case layers.LayerTypeTCP:

View File

@@ -45,7 +45,7 @@ func TestTracePacket(t *testing.T) {
},
}
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
if !statefulMode {
@@ -97,7 +97,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -121,7 +121,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -150,7 +150,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -178,7 +178,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -205,7 +205,7 @@ func TestTracePacket(t *testing.T) {
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -231,7 +231,7 @@ func TestTracePacket(t *testing.T) {
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -332,7 +332,7 @@ func TestTracePacket(t *testing.T) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -355,7 +355,7 @@ func TestTracePacket(t *testing.T) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -379,7 +379,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -423,7 +423,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {

View File

@@ -6,7 +6,7 @@
!define DESCRIPTION "Connect your devices into a secure WireGuard-based overlay network with SSO, MFA, and granular access controls."
!define INSTALLER_NAME "netbird-installer.exe"
!define MAIN_APP_EXE "Netbird"
!define ICON "ui\\assets\\netbird.ico"
!define ICON "ui\\build\\windows\\icon.ico"
!define BANNER "ui\\build\\banner.bmp"
!define LICENSE_DATA "..\\LICENSE"
@@ -280,6 +280,43 @@ CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
SectionEnd
# Install the Microsoft Edge WebView2 runtime if it isn't already present.
# Macro adapted from Wails3's NSIS template (wails_tools.nsh): a registry
# probe followed by a silent install of the embedded evergreen bootstrapper.
# The MicrosoftEdgeWebview2Setup.exe payload is staged next to this script
# by the sign-pipelines build step (`wails3 generate webview2bootstrapper`).
!macro nb.webview2runtime
SetRegView 64
# Per-machine install marker — populated when the runtime ships with
# Edge or has been installed by an admin previously.
ReadRegStr $0 HKLM "SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" "pv"
${If} $0 != ""
Goto webview2_ok
${EndIf}
# Per-user fallback for HKCU installs.
ReadRegStr $0 HKCU "Software\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" "pv"
${If} $0 != ""
Goto webview2_ok
${EndIf}
SetDetailsPrint both
DetailPrint "Installing: WebView2 Runtime"
SetDetailsPrint listonly
InitPluginsDir
CreateDirectory "$pluginsdir\webview2bootstrapper"
SetOutPath "$pluginsdir\webview2bootstrapper"
File "MicrosoftEdgeWebview2Setup.exe"
ExecWait '"$pluginsdir\webview2bootstrapper\MicrosoftEdgeWebview2Setup.exe" /silent /install'
SetDetailsPrint both
webview2_ok:
!macroend
Section -WebView2
!insertmacro nb.webview2runtime
SectionEnd
Section -Post
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service install'
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service start'
@@ -326,9 +363,9 @@ DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll"
!if ${ARCH} == "amd64"
# Legacy: pre-Wails installs shipped opengl32.dll (Mesa3D for Fyne); remove
# any leftover copy on uninstall so old upgrades don't leave it behind.
Delete "$INSTDIR\opengl32.dll"
!endif
DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR"

View File

@@ -1,190 +0,0 @@
package acl
import (
"net/netip"
"sync"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// TestNetworkZeroPrefixIsRoute guards the route-vs-peer dispatch
// invariant: the backends classify a rule as a peer rule purely by the
// absence of a destination (neither prefix nor set). A default route
// (0.0.0.0/0 or ::/0) is a valid prefix and must therefore classify as
// a route, not collapse into the peer path.
func TestNetworkZeroPrefixIsRoute(t *testing.T) {
for _, p := range []string{"0.0.0.0/0", "::/0", "10.0.0.0/8"} {
n := fwmgr.Network{Prefix: netip.MustParsePrefix(p)}
assert.True(t, n.IsPrefix(), "%s must report IsPrefix", p)
assert.True(t, n.IsPrefix() || n.IsSet(), "%s must classify as a route", p)
}
// A zero-value Network is the only peer-rule shape.
var empty fwmgr.Network
assert.False(t, empty.IsPrefix(), "zero Network must not be a prefix")
assert.False(t, empty.IsSet(), "zero Network must not be a set")
}
// TestDetermineDestinationAlwaysRoute verifies determineDestination
// never yields an empty Network for a valid route rule: every branch
// (static prefix, default route, dynamic with/without domains, with and
// without a local resolver) produces a destination that classifies as a
// route. If this regresses, a route rule would be dispatched down the
// peer path, which matches on source only.
func TestDetermineDestinationAlwaysRoute(t *testing.T) {
v4 := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
v6 := []netip.Prefix{netip.MustParsePrefix("2001:db8::/48")}
cases := []struct {
name string
rule *mgmProto.RouteFirewallRule
resolver bool
sources []netip.Prefix
}{
{"static prefix", &mgmProto.RouteFirewallRule{Destination: "192.168.0.0/16"}, false, v4},
{"static default route", &mgmProto.RouteFirewallRule{Destination: "0.0.0.0/0"}, false, v4},
{"dynamic with domains + resolver", &mgmProto.RouteFirewallRule{IsDynamic: true, Domains: []string{"example.com"}}, true, v4},
{"dynamic no domains + resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v4},
{"dynamic no domains + resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v6},
{"dynamic + no local resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v4},
{"dynamic + no local resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v6},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
dest, err := determineDestination(tc.rule, tc.resolver, tc.sources)
require.NoError(t, err)
assert.True(t, dest.IsPrefix() || dest.IsSet(),
"destination must classify as a route, got empty Network")
})
}
}
// countingFirewall wraps a real firewall.Manager and counts filter-rule
// add/delete calls so a test can assert how many backing rules the acl
// manager actually creates and tears down.
type countingFirewall struct {
fwmgr.Manager
mu sync.Mutex
addCalls int
dels int
ruleIDs map[fwmgr.RuleID]struct{}
}
// distinctRules returns the number of distinct backing rules the
// backend produced. Because the backend dedups identical content,
// repeated AddFilterRule calls for the same rule resolve to one id.
func (f *countingFirewall) distinctRules() int {
f.mu.Lock()
defer f.mu.Unlock()
return len(f.ruleIDs)
}
func (f *countingFirewall) AddFilterRule(id []byte, sources []netip.Prefix, destination fwmgr.Network, proto fwmgr.Protocol, sPort, dPort *fwmgr.Port, action fwmgr.Action) (fwmgr.Rule, error) {
rule, err := f.Manager.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
if err == nil {
f.mu.Lock()
f.addCalls++
if f.ruleIDs == nil {
f.ruleIDs = make(map[fwmgr.RuleID]struct{})
}
if rule != nil {
f.ruleIDs[rule.ID()] = struct{}{}
}
f.mu.Unlock()
}
return rule, err
}
func (f *countingFirewall) DeleteFilterRule(r fwmgr.Rule) error {
err := f.Manager.DeleteFilterRule(r)
if err == nil {
f.mu.Lock()
f.dels++
delete(f.ruleIDs, r.ID())
f.mu.Unlock()
}
return err
}
func newCountingACL(t *testing.T) (*DefaultManager, *countingFirewall, func()) {
t.Helper()
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
fw := &countingFirewall{Manager: realFW}
cleanup := func() {
require.NoError(t, realFW.Close(nil))
ctrl.Finish()
}
return NewDefaultManager(fw), fw, cleanup
}
// TestDuplicateContentPoliciesShareOneRule verifies the dedup contract
// the backends rely on: two policies that authorize an identical flow
// (same selector and sources) collapse to a single backing firewall
// rule, and that rule survives until BOTH policies are gone. This is
// why the backend can dedup on add without refcounting on delete: the
// acl manager's pair key matches the backend's content key, so add and
// delete stay balanced per content key across full-state reapplies.
func TestDuplicateContentPoliciesShareOneRule(t *testing.T) {
acl, fw, cleanup := newCountingACL(t)
defer cleanup()
ruleA := &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
}
ruleB := &mgmProto.FirewallRule{
PolicyID: []byte("policy-B"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
}
// Both policies present: identical content collapses to one rule.
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleA, ruleB}, FirewallRulesIsEmpty: false}, false)
assert.Equal(t, 1, fw.distinctRules(), "identical-content policies must produce one backing rule")
assert.Equal(t, 1, len(acl.peerRulesPairs), "one content key, one pair")
// Drop policy A only: the shared rule is still authorized by B, so
// nothing is deleted.
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleB}, FirewallRulesIsEmpty: false}, false)
assert.Equal(t, 1, fw.distinctRules(), "no new backing rule on reapply")
assert.Equal(t, 0, fw.dels, "rule must survive while any policy still authorizes it")
assert.Equal(t, 1, len(acl.peerRulesPairs))
// Drop policy B too: now the content key has no authorizer and the
// single backing rule is removed exactly once.
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}, false)
assert.Equal(t, 1, fw.dels, "rule removed once when last policy is gone")
assert.Equal(t, 0, len(acl.peerRulesPairs))
}

View File

@@ -1,318 +0,0 @@
package acl
import (
"errors"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
)
// TestGroupPeerRulesPolicyIDSeparates verifies that two FirewallRules
// with identical selectors but different PolicyIDs do NOT get merged
// into one group, so each policy's sources merge under its own
// attribution id. (Identical-content groups may still dedup to one
// backing rule at the backend; see TestDuplicateContentPoliciesShareOneRule.)
func TestGroupPeerRulesPolicyIDSeparates(t *testing.T) {
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PolicyID: []byte("policy-B"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "rules with different PolicyIDs must produce separate groups")
}
// TestGroupPeerRulesFamilySeparates verifies that v4 and v6 rules
// belonging to the same policy don't merge.
func TestGroupPeerRulesFamilySeparates(t *testing.T) {
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PolicyID: []byte("policy-A"),
PeerIP: "2001:db8::1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "rules of different families must produce separate groups")
var sawV4, sawV6 bool
for _, g := range groups {
require.Len(t, g.sources, 1)
if g.sources[0].Addr().Is4() {
sawV4 = true
}
if g.sources[0].Addr().Is6() {
sawV6 = true
}
}
assert.True(t, sawV4 && sawV6)
}
// TestGroupPeerRulesSplitsMixedFamilySingleRule verifies that a single
// FirewallRule carrying both v4 and v6 source prefixes is split into one
// group per family. Each backend keys a rule to a single family, so a
// group whose sources span families would mismatch the other family's
// sources. mgmt normally emits one rule per family; this guards against
// a mixed-family rule slipping through.
func TestGroupPeerRulesSplitsMixedFamilySingleRule(t *testing.T) {
srcs := [][]byte{
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::1")),
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::2")),
}
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
SourcePrefixes: srcs,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "mixed-family sources in one rule must split into two groups")
for _, g := range groups {
require.Len(t, g.sources, 2)
v6 := prefixIsV6(g.sources[0])
for _, s := range g.sources {
assert.Equal(t, v6, prefixIsV6(s), "every source in a group must share one family")
}
}
}
// TestGroupPeerRulesMergesSameSelector verifies that rules sharing
// every distinguishing field (policy, family, direction, action,
// proto, port) collapse into a single multi-source group.
func TestGroupPeerRulesMergesSameSelector(t *testing.T) {
mk := func(peerIP string) *mgmProto.FirewallRule {
return &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: peerIP,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
}
}
rules := []*mgmProto.FirewallRule{mk("10.0.0.1"), mk("10.0.0.2"), mk("10.0.0.3")}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 1)
require.Len(t, groups[0].sources, 3)
}
// TestGroupPeerRulesPortSeparates verifies that PortInfo is part of the
// selector key: rules differing only in port must not merge, and a
// single port must not merge with a range. A regression dropping the
// port from the key would collapse rules for different ports into one.
func TestGroupPeerRulesPortSeparates(t *testing.T) {
mkPort := func(peerIP string, port uint32) *mgmProto.FirewallRule {
return &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: peerIP,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Port{Port: port}},
}
}
groups, denyErr, err := groupPeerRules([]*mgmProto.FirewallRule{
mkPort("10.0.0.1", 80), mkPort("10.0.0.2", 80), mkPort("10.0.0.3", 443),
})
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "rules on different ports must not merge")
rangeRule := &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Range_{Range: &mgmProto.PortInfo_Range{Start: 80, End: 90}}},
}
groups, denyErr, err = groupPeerRules([]*mgmProto.FirewallRule{mkPort("10.0.0.1", 80), rangeRule})
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2, "a single port and a range must not merge")
}
// TestGroupPeerRulesUsesSourcePrefixesWhenPresent verifies that the
// new sourcePrefixes wire field is consumed and produces a
// multi-source group in one shot (no client-side merging needed).
func TestGroupPeerRulesUsesSourcePrefixesWhenPresent(t *testing.T) {
srcs := [][]byte{
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.3")),
}
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
SourcePrefixes: srcs,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 1)
require.Len(t, groups[0].sources, 3)
}
// TestGroupPeerRulesActionSeparates verifies the obvious: accept
// and drop rules with the same selector don't merge.
func TestGroupPeerRulesActionSeparates(t *testing.T) {
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, denyErr, err := groupPeerRules(rules)
require.NoError(t, denyErr)
require.NoError(t, err)
require.Len(t, groups, 2)
}
// failingDeleteFirewall wraps a real firewall.Manager and forces the
// next N DeleteFilterRule calls to fail. Used to verify that the acl
// manager retains rules whose deletion was rejected by the backend,
// so they get retried on the next ApplyFiltering pass instead of
// becoming orphans.
type failingDeleteFirewall struct {
fwmgr.Manager
failCount int
}
func (f *failingDeleteFirewall) DeleteFilterRule(r fwmgr.Rule) error {
if f.failCount > 0 {
f.failCount--
return errors.New("simulated delete failure")
}
return f.Manager.DeleteFilterRule(r)
}
// TestApplyFilteringRetainsRulesOnDeleteFailure verifies that a
// transient DeleteFilterRule error doesn't make the acl manager
// forget about a rule. The rule must remain in peerRulesPairs so the
// next ApplyFiltering pass attempts the delete again.
func TestApplyFilteringRetainsRulesOnDeleteFailure(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() { require.NoError(t, realFW.Close(nil)) }()
fw := &failingDeleteFirewall{Manager: realFW}
acl := NewDefaultManager(fw)
// First pass: install a rule.
netmap1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(netmap1, false)
require.Equal(t, 1, len(acl.peerRulesPairs), "rule should be installed")
// Second pass: remove the rule from the map. The backend will
// fail the delete; the acl manager must retain the rule.
fw.failCount = 1
netmap2 := &mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}
acl.ApplyFiltering(netmap2, false)
require.Equal(t, 1, len(acl.peerRulesPairs),
"rule must be retained when DeleteFilterRule fails so it gets retried")
// Third pass: same map, backend no longer fails. The rule
// should now succeed in being removed.
acl.ApplyFiltering(netmap2, false)
require.Equal(t, 0, len(acl.peerRulesPairs), "retry should succeed")
}

View File

@@ -5,18 +5,18 @@ import (
"encoding/hex"
"fmt"
"net/netip"
"slices"
"strconv"
"github.com/netbirdio/netbird/client/firewall/manager"
)
// RuleID aliases manager.RuleID so existing nbid.RuleID references
// keep working while the canonical type lives in the firewall package.
type RuleID = manager.RuleID
type RuleID string
// GenerateRuleID returns a deterministic content hash identifying a filter rule.
func GenerateRuleID(
func (r RuleID) ID() string {
return string(r)
}
func GenerateRouteRuleKey(
sources []netip.Prefix,
destination manager.Network,
proto manager.Protocol,
@@ -24,7 +24,6 @@ func GenerateRuleID(
dPort *manager.Port,
action manager.Action,
) RuleID {
sources = slices.Clone(sources)
manager.SortPrefixes(sources)
h := sha256.New()

View File

@@ -1,6 +1,8 @@
package acl
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"net/netip"
@@ -21,10 +23,6 @@ import (
var ErrSourceRangesEmpty = errors.New("sources range is empty")
// ErrNoRuleReturned is returned when the firewall backend reports success
// from AddFilterRule but yields no rule to track.
var ErrNoRuleReturned = errors.New("backend returned no rule")
// Manager is a ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
@@ -33,46 +31,17 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]firewall.Rule
routeRules map[id.RuleID]struct{}
mutex sync.Mutex
}
// peerRuleGroup collapses a set of single-source FirewallRules sharing
// the same selector into one multi-source rule to push to the backend.
type peerRuleGroup struct {
direction mgmProto.RuleDirection
action mgmProto.RuleAction
protocol mgmProto.RuleProtocol
port *mgmProto.PortInfo
// legacyPort is used only when PortInfo is empty (old management).
legacyPort string
policyID []byte
sources []netip.Prefix
}
// peerRuleKey is the comparable selector that decides which single-source
// rules merge into one group. Rules with an equal key collapse into one
// multi-source backend rule. PortInfo is flattened into its scalar fields
// so the key compares by value; policyID keeps policies separate so two
// policies authorizing different peers don't merge under one attribution.
type peerRuleKey struct {
v6 bool
policyID string
direction mgmProto.RuleDirection
action mgmProto.RuleAction
protocol mgmProto.RuleProtocol
legacyPort string
port uint16
rangeStart uint16
rangeEnd uint16
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
firewall: fm,
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
routeRules: make(map[id.RuleID]firewall.Rule),
routeRules: make(map[id.RuleID]struct{}),
}
}
@@ -99,12 +68,10 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
time.Since(start), total)
}()
if err := d.applyPeerACLs(networkMap); err != nil {
log.Errorf("apply peer ACLs: %v", err)
}
d.applyPeerACLs(networkMap)
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("apply route ACLs: %v", err)
log.Errorf("Failed to apply route ACLs: %v", err)
}
if err := d.firewall.Flush(); err != nil {
@@ -112,7 +79,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
}
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) error {
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules := networkMap.FirewallRules
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
@@ -135,158 +102,59 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) error {
)
}
// Group incoming single-source rules from management by their
// (direction, action, proto, port) selector and merge sources.
// One call to the firewall backend per merged rule.
// A deny we cannot decode would leave its traffic unblocked, so skip
// the whole pass and keep existing rules until the next sync.
groups, denyErr, err := groupPeerRules(rules)
if denyErr != nil {
return fmt.Errorf("decode deny rule sources: %w", denyErr)
}
newRulePairs := make(map[id.RuleID][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string)
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
// the missing deny. Currently we accumulate errors and continue.
var merr *multierror.Error
if err != nil {
merr = multierror.Append(merr, err)
}
// Apply denies first. A deny that fails to install is a security
// failure (fail-open), so if any deny errors we roll back the
// denies we already installed in this pass and bail out without
// installing any accept. Pre-existing rules stay untouched until
// the next successful pass clears them.
denies, accepts := splitDenyAccept(groups)
if err := d.installPeerGroups(denies, newRulePairs, true); err != nil {
return fmt.Errorf("install deny rules: %w", err)
}
if err := d.installPeerGroups(accepts, newRulePairs, false); err != nil {
merr = multierror.Append(merr, err)
}
// Tear down rules that disappeared from the networkmap. Any rule
// the backend refuses to delete stays in our tracking so it gets
// retried on the next ApplyFiltering. Otherwise a transient
// delete failure would leak the rule in the firewall until the
// process exits.
for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; ok {
continue
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
selector := d.getRuleGroupingSelector(r)
ipsetName, ok := ipsetByRuleSelectors[selector]
if !ok {
d.ipsetCounter++
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
ipsetByRuleSelectors[selector] = ipsetName
}
var remaining []firewall.Rule
for _, rule := range rules {
if err := d.firewall.DeleteFilterRule(rule); err != nil {
log.Errorf("failed to delete peer firewall rule, will retry: %v", err)
remaining = append(remaining, rule)
}
}
if len(remaining) > 0 {
newRulePairs[pairID] = remaining
}
}
d.peerRulesPairs = newRulePairs
return nberrors.FormatErrorOrNil(merr)
}
// installPeerGroups applies each group and records the resulting rule
// pairs in newRulePairs. With atomic set (deny rules), a single failure
// rolls back every rule installed in this call and returns, leaving the
// firewall exactly as before: denies are fail-closed and must be applied
// all-or-nothing. With atomic unset (accept rules), failures are
// accumulated and the remaining groups still install, so one malformed
// allow cannot drop every other legitimate allow in the pass.
func (d *DefaultManager) installPeerGroups(groups []*peerRuleGroup, newRulePairs map[id.RuleID][]firewall.Rule, atomic bool) error {
var freshlyInstalled []id.RuleID
var merr *multierror.Error
for _, g := range groups {
pairID, rulePair, err := d.applyPeerGroup(g)
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil {
if atomic {
d.rollbackInstalled(freshlyInstalled)
return fmt.Errorf("apply firewall rule: %w", err)
}
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
continue
}
if len(rulePair) == 0 {
continue
if len(rulePair) > 0 {
d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
if _, existed := d.peerRulesPairs[pairID]; !existed {
freshlyInstalled = append(freshlyInstalled, pairID)
}
d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
return nberrors.FormatErrorOrNil(merr)
}
func (d *DefaultManager) rollbackInstalled(pairIDs []id.RuleID) {
var merr *multierror.Error
for _, pairID := range pairIDs {
for _, rule := range d.peerRulesPairs[pairID] {
if err := d.firewall.DeleteFilterRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rule %s: %w", pairID, err))
if merr != nil {
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
}
for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules {
if err := d.firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete peer firewall rule: %v", err)
continue
}
}
delete(d.peerRulesPairs, pairID)
}
delete(d.peerRulesPairs, pairID)
}
if err := nberrors.FormatErrorOrNil(merr); err != nil {
log.Errorf("rollback peer rules: %v", err)
}
}
func (d *DefaultManager) applyPeerGroup(g *peerRuleGroup) (id.RuleID, []firewall.Rule, error) {
protocol, err := ConvertToFirewallProtocol(g.protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
}
action, err := convertFirewallAction(g.action)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
}
port, err := resolveGroupPort(g)
if err != nil {
return "", nil, err
}
var fwRule firewall.Rule
switch g.direction {
case mgmProto.RuleDirection_IN:
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, nil, port, action)
case mgmProto.RuleDirection_OUT:
if d.firewall.IsStateful() {
return "", nil, nil
}
if shouldSkipInvertedRule(protocol, port) {
return "", nil, nil
}
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, port, nil, action)
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
if err != nil {
return "", nil, fmt.Errorf("add firewall rule: %w", err)
}
if fwRule == nil {
return "", nil, fmt.Errorf("add firewall rule: %w", ErrNoRuleReturned)
}
// Derive the pair id from the backend rule, like the route path:
// the backend dedups identical content, so two policies authorizing
// the same flow resolve to the same id and a single backing rule.
return fwRule.ID(), []firewall.Rule{fwRule}, nil
d.peerRulesPairs = newRulePairs
}
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
newRouteRules := make(map[id.RuleID]firewall.Rule, len(rules))
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error
// Apply new rules - firewall manager will return the existing rule if already present
// Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules {
addedRule, err := d.applyRouteACL(rule, dynamicResolver)
id, err := d.applyRouteACL(rule, dynamicResolver)
if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
@@ -295,18 +163,16 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
}
continue
}
newRouteRules[addedRule.ID()] = addedRule
newRouteRules[id] = struct{}{}
}
// Tear down old route rules; retain ones the backend refused so a
// transient failure doesn't leave orphaned rules in the firewall.
for ruleID, rule := range d.routeRules {
if _, exists := newRouteRules[ruleID]; exists {
continue
}
if err := d.firewall.DeleteFilterRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete route rule, will retry: %w", err))
newRouteRules[ruleID] = rule
// Clean up old firewall rules
for id := range d.routeRules {
if _, exists := newRouteRules[id]; !exists {
if err := d.firewall.DeleteRouteRule(id); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
}
// implicitly deleted from the map
}
}
@@ -314,196 +180,102 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
return nberrors.FormatErrorOrNil(merr)
}
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (firewall.Rule, error) {
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 {
return nil, ErrSourceRangesEmpty
return "", ErrSourceRangesEmpty
}
var sources []netip.Prefix
for _, sourceRange := range rule.SourceRanges {
source, err := netip.ParsePrefix(sourceRange)
if err != nil {
return nil, fmt.Errorf("parse source range: %w", err)
return "", fmt.Errorf("parse source range: %w", err)
}
sources = append(sources, firewall.UnmapPrefix(source))
sources = append(sources, source)
}
destination, err := determineDestination(rule, dynamicResolver, sources)
if err != nil {
return nil, fmt.Errorf("determine destination: %w", err)
return "", fmt.Errorf("determine destination: %w", err)
}
protocol, err := ConvertToFirewallProtocol(rule.Protocol)
protocol, err := convertToFirewallProtocol(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("invalid protocol: %w", err)
return "", fmt.Errorf("invalid protocol: %w", err)
}
action, err := convertFirewallAction(rule.Action)
if err != nil {
return nil, fmt.Errorf("invalid action: %w", err)
return "", fmt.Errorf("invalid action: %w", err)
}
dPorts := convertPortInfo(rule.PortInfo)
addedRule, err := d.firewall.AddFilterRule(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
if err != nil {
return nil, fmt.Errorf("add route rule: %w", err)
}
if addedRule == nil {
return nil, fmt.Errorf("add route rule: %w", ErrNoRuleReturned)
return "", fmt.Errorf("add route rule: %w", err)
}
return addedRule, nil
return id.RuleID(addedRule.ID()), nil
}
// splitDenyAccept partitions groups by action so denies can be
// applied before accepts. Order within each bucket is preserved.
func splitDenyAccept(groups []*peerRuleGroup) (denies, accepts []*peerRuleGroup) {
for _, g := range groups {
if g.action == mgmProto.RuleAction_DROP {
denies = append(denies, g)
} else {
accepts = append(accepts, g)
}
}
return denies, accepts
}
// groupPeerRules merges single-source rules sharing a selector into
// multi-source groups. It splits source-decode failures by action:
// denyErr is non-nil when a deny rule could not be decoded, which is a
// fail-open risk the caller must treat as fatal for the pass; err
// carries the tolerable accept-rule failures the caller can log and
// continue past.
func groupPeerRules(rules []*mgmProto.FirewallRule) (groups []*peerRuleGroup, denyErr error, err error) {
var denyMerr, acceptMerr *multierror.Error
byKey := make(map[peerRuleKey]*peerRuleGroup)
order := make([]peerRuleKey, 0)
for _, r := range rules {
srcs, decErr := extractRuleSources(r)
if decErr != nil {
if r.Action == mgmProto.RuleAction_DROP {
denyMerr = multierror.Append(denyMerr, decErr)
} else {
acceptMerr = multierror.Append(acceptMerr, decErr)
}
continue
}
// A single FirewallRule normally carries one address family, but
// split by family defensively: each backend keys a rule to one
// family and would mismatch sources of the other, so a group's
// sources must never span families.
v4, v6 := splitPrefixesByFamily(srcs)
for _, sub := range []struct {
isV6 bool
sources []netip.Prefix
}{{false, v4}, {true, v6}} {
if len(sub.sources) == 0 {
continue
}
key := ruleGroupKey(r, sub.isV6)
g, ok := byKey[key]
if !ok {
g = &peerRuleGroup{
direction: r.Direction,
action: r.Action,
protocol: r.Protocol,
port: r.PortInfo,
legacyPort: r.Port,
policyID: r.PolicyID,
}
byKey[key] = g
order = append(order, key)
}
g.sources = append(g.sources, sub.sources...)
}
}
out := make([]*peerRuleGroup, 0, len(order))
for _, k := range order {
out = append(out, byKey[k])
}
return out, nberrors.FormatErrorOrNil(denyMerr), nberrors.FormatErrorOrNil(acceptMerr)
}
func prefixIsV6(p netip.Prefix) bool {
return p.Addr().Is6() && !p.Addr().Is4In6()
}
// splitPrefixesByFamily partitions prefixes into IPv4 and IPv6 groups.
func splitPrefixesByFamily(prefixes []netip.Prefix) (v4, v6 []netip.Prefix) {
for _, p := range prefixes {
if prefixIsV6(p) {
v6 = append(v6, p)
} else {
v4 = append(v4, p)
}
}
return v4, v6
}
// ruleGroupKey builds the selector key for a rule. v6 must reflect the
// rule's source family: mgmt emits one rule per family and mixing them
// would break ICMP-variant selection in uspfilter.
func ruleGroupKey(r *mgmProto.FirewallRule, v6 bool) peerRuleKey {
k := peerRuleKey{
v6: v6,
policyID: string(r.PolicyID),
direction: r.Direction,
action: r.Action,
protocol: r.Protocol,
legacyPort: r.Port,
}
if pi := r.PortInfo; pi != nil {
k.port = uint16(pi.GetPort())
if rng := pi.GetRange(); rng != nil {
k.rangeStart = uint16(rng.GetStart())
k.rangeEnd = uint16(rng.GetEnd())
}
}
return k
}
// extractRuleSources returns all source prefixes the rule applies to.
// New management populates sourcePrefixes; older management sets PeerIP.
func extractRuleSources(r *mgmProto.FirewallRule) ([]netip.Prefix, error) {
if len(r.SourcePrefixes) > 0 {
out := make([]netip.Prefix, 0, len(r.SourcePrefixes))
for _, raw := range r.SourcePrefixes {
addr, err := netiputil.DecodeAddr(raw)
if err != nil {
return nil, fmt.Errorf("decode source prefix: %w", err)
}
out = append(out, netip.PrefixFrom(addr.Unmap(), addr.Unmap().BitLen()))
}
return out, nil
}
//nolint:staticcheck // PeerIP used for backward compatibility with old management
addr, err := netip.ParseAddr(r.PeerIP)
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (id.RuleID, []firewall.Rule, error) {
ip, err := extractRuleIP(r)
if err != nil {
return nil, fmt.Errorf("invalid IP address, skipping firewall rule")
return "", nil, err
}
addr = addr.Unmap()
return []netip.Prefix{netip.PrefixFrom(addr, addr.BitLen())}, nil
}
func resolveGroupPort(g *peerRuleGroup) (*firewall.Port, error) {
if !portInfoEmpty(g.port) {
return convertPortInfo(g.port), nil
protocol, err := convertToFirewallProtocol(r.Protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
if g.legacyPort != "" {
value, err := strconv.ParseUint(g.legacyPort, 10, 16)
action, err := convertFirewallAction(r.Action)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
var port *firewall.Port
if !portInfoEmpty(r.PortInfo) {
port = convertPortInfo(r.PortInfo)
} else if r.Port != "" {
// old version of management, single port
value, err := strconv.Atoi(r.Port)
if err != nil {
return nil, fmt.Errorf("invalid port: %w", err)
return "", nil, fmt.Errorf("invalid port: %w", err)
}
return &firewall.Port{
port = &firewall.Port{
Values: []uint16{uint16(value)},
}, nil
}
}
// nolint:nilnil // a nil port legitimately means "no port restriction"
return nil, nil
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil
}
var rules []firewall.Rule
switch r.Direction {
case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT:
if d.firewall.IsStateful() {
return "", nil, nil
}
// return traffic for outbound connections if firewall is stateless
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
if err != nil {
return "", nil, err
}
return ruleID, rules, nil
}
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
@@ -522,9 +294,85 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
}
}
// ConvertToFirewallProtocol maps a management rule protocol to the
// firewall protocol type.
func ConvertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
func (d *DefaultManager) addInRules(
id []byte,
ip netip.Addr,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err)
}
return rule, nil
}
func (d *DefaultManager) addOutRules(
id []byte,
ip netip.Addr,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
if shouldSkipInvertedRule(protocol, port) {
return nil, nil
}
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err)
}
return rule, nil
}
// getPeerRuleID returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getPeerRuleID(
ip netip.Addr,
proto firewall.Protocol,
direction int,
port *firewall.Port,
action firewall.Action,
) id.RuleID {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
if port != nil {
idStr += port.String()
}
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
}
// getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
}
// extractRuleIP extracts the peer IP from a firewall rule.
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
// Otherwise fall back to the deprecated PeerIP string field (old management).
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
if len(r.SourcePrefixes) > 0 {
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
if err != nil {
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
}
return addr.Unmap(), nil
}
//nolint:staticcheck // PeerIP used for backward compatibility with old management
addr, err := netip.ParseAddr(r.PeerIP)
if err != nil {
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
}
return addr.Unmap(), nil
}
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
switch protocol {
case mgmProto.RuleProtocol_TCP:
return firewall.ProtocolTCP, nil

View File

@@ -9,7 +9,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
fwmanager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
@@ -77,9 +76,9 @@ func TestDefaultManager(t *testing.T) {
})
t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[fwmanager.RuleID]struct{}{}
existedPairs := map[string]struct{}{}
for id := range acl.peerRulesPairs {
existedPairs[id] = struct{}{}
existedPairs[id.ID()] = struct{}{}
}
// remove first rule
@@ -106,7 +105,7 @@ func TestDefaultManager(t *testing.T) {
// check that old rule was removed
previousCount := 0
for id := range acl.peerRulesPairs {
if _, ok := existedPairs[id]; ok {
if _, ok := existedPairs[id.ID()]; ok {
previousCount++
}
}

View File

@@ -3,6 +3,7 @@ package auth
import (
"context"
"net/url"
"strings"
"sync"
"time"
@@ -21,6 +22,25 @@ import (
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// peerLoginExpiredMsg is the exact phrase the management server returns
// when a previously SSO-enrolled peer's login has expired. Sourced from
// shared/management/status/error.go (NewPeerLoginExpiredError). Matched
// by substring so a future server-side rewording that keeps the phrase
// still triggers the friendly fallback in Login().
const peerLoginExpiredMsg = "peer login has expired"
// errSetupKeyOnSSOExpiredPeer replaces the raw management error when the
// user runs `netbird login -k <setup-key>` against a peer that was
// originally enrolled via SSO. Wrapped in a PermissionDenied gRPC status
// so callers' existing isPermissionDenied / isAuthError checks still
// classify it correctly (early-exit from retry backoff, StatusNeedsLogin
// in the server state machine).
var errSetupKeyOnSSOExpiredPeer = status.Error(
codes.PermissionDenied,
"this peer was originally enrolled via SSO and its session has expired. "+
"Setup keys can only enrol new peers — run `netbird up` (interactive SSO) to re-login.",
)
// Auth manages authentication operations with the management server
// It maintains a long-lived connection and automatically handles reconnection with backoff
type Auth struct {
@@ -184,6 +204,15 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil {
// The peer pub-key is already on file with the management
// server (originally enrolled via SSO) and the session has
// expired. The setup-key path can only enrol new peers, so
// retrying with -k will keep failing. Replace the raw mgm
// message with an actionable hint that tells the user to
// re-authenticate via SSO instead.
if setupKey != "" && jwtToken == "" && isPeerLoginExpired(err) {
err = errSetupKeyOnSSOExpiredPeer
}
isAuthError = isPermissionDenied(err)
return err
}
@@ -474,3 +503,16 @@ func isLoginNeeded(err error) bool {
func isRegistrationNeeded(err error) bool {
return isPermissionDenied(err)
}
// isPeerLoginExpired reports whether err is the management server's
// "peer login has expired" PermissionDenied response. Used by Login to
// detect the case where the caller passed a setup-key but the peer is
// actually an SSO-enrolled record whose session needs refreshing — the
// setup-key path cannot help there.
func isPeerLoginExpired(err error) bool {
if !isPermissionDenied(err) {
return false
}
s, _ := status.FromError(err)
return strings.Contains(s.Message(), peerLoginExpiredMsg)
}

View File

@@ -0,0 +1,80 @@
package auth
import (
"errors"
"strings"
"testing"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestIsPeerLoginExpired(t *testing.T) {
cases := []struct {
name string
err error
want bool
}{
{
name: "nil",
err: nil,
want: false,
},
{
name: "plain error (not a gRPC status)",
err: errors.New("network read: connection reset"),
want: false,
},
{
name: "PermissionDenied with different message",
err: status.Error(codes.PermissionDenied, "user is blocked"),
want: false,
},
{
name: "Unauthenticated with the expected phrase",
// Wrong status code — must still return false.
err: status.Error(codes.Unauthenticated, "peer login has expired, please log in once more"),
want: false,
},
{
name: "exact server message",
err: status.Error(codes.PermissionDenied, "peer login has expired, please log in once more"),
want: true,
},
{
name: "phrase as substring",
// Future-proofing: if mgm reworords but keeps the phrase,
// the friendly fallback must still kick in.
err: status.Error(codes.PermissionDenied, "session refused: peer login has expired (account=foo)"),
want: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if got := isPeerLoginExpired(tc.err); got != tc.want {
t.Fatalf("isPeerLoginExpired(%v) = %v, want %v", tc.err, got, tc.want)
}
})
}
}
func TestErrSetupKeyOnSSOExpiredPeer(t *testing.T) {
// Sentinel must surface as PermissionDenied so the upstream
// isPermissionDenied / isAuthError checks classify it correctly
// (short-circuit retry backoff, set StatusNeedsLogin).
if !isPermissionDenied(errSetupKeyOnSSOExpiredPeer) {
t.Fatalf("errSetupKeyOnSSOExpiredPeer must be a PermissionDenied gRPC error")
}
// Message must actually mention SSO and `netbird up` so it is
// actionable for the end user. Loose substring checks keep the
// test resilient to copy edits.
s, _ := status.FromError(errSetupKeyOnSSOExpiredPeer)
msg := strings.ToLower(s.Message())
for _, want := range []string{"sso", "netbird up"} {
if !strings.Contains(msg, want) {
t.Errorf("sentinel message should contain %q, got %q", want, s.Message())
}
}
}

View File

@@ -0,0 +1,89 @@
package auth
import (
"context"
"sync"
"time"
)
// PendingFlow stores an in-progress OAuth flow between the RPC that
// initiates it (returns the verification URI to the UI) and the RPC
// that waits for the user to complete it. The flow handle, the
// device-code info, and the absolute expiry are kept together so the
// waiting RPC can validate the device code and reuse the same flow.
//
// PendingFlow is safe for concurrent use; callers must not access the
// stored fields directly.
type PendingFlow struct {
mu sync.Mutex
flow OAuthFlow
info AuthFlowInfo
expiresAt time.Time
waitCancel context.CancelFunc
}
// NewPendingFlow returns an empty PendingFlow ready to be populated by Set.
func NewPendingFlow() *PendingFlow {
return &PendingFlow{}
}
// Set stores the flow and its authorization info, computing the absolute
// expiry from info.ExpiresIn (seconds, as returned by the IdP).
func (p *PendingFlow) Set(flow OAuthFlow, info AuthFlowInfo) {
p.mu.Lock()
defer p.mu.Unlock()
p.flow = flow
p.info = info
p.expiresAt = time.Now().Add(time.Duration(info.ExpiresIn) * time.Second)
}
// Get returns the stored flow, info, and whether a flow is currently
// pending. Returns (nil, zero, false) after Clear or before Set.
func (p *PendingFlow) Get() (OAuthFlow, AuthFlowInfo, bool) {
p.mu.Lock()
defer p.mu.Unlock()
if p.flow == nil {
return nil, AuthFlowInfo{}, false
}
return p.flow, p.info, true
}
// ExpiresAt returns the absolute expiry of the pending flow. Returns
// the zero time when no flow is pending.
func (p *PendingFlow) ExpiresAt() time.Time {
p.mu.Lock()
defer p.mu.Unlock()
return p.expiresAt
}
// SetWaitCancel records the cancel function for the goroutine currently
// blocked in WaitToken so a new RequestAuth can preempt it.
func (p *PendingFlow) SetWaitCancel(cancel context.CancelFunc) {
p.mu.Lock()
defer p.mu.Unlock()
p.waitCancel = cancel
}
// CancelWait invokes and clears the stored wait-cancel, if any. Safe to
// call when no wait is in progress.
func (p *PendingFlow) CancelWait() {
p.mu.Lock()
cancel := p.waitCancel
p.waitCancel = nil
p.mu.Unlock()
if cancel != nil {
cancel()
}
}
// Clear resets the pending flow to empty. Any stored wait-cancel is
// dropped without being invoked — call CancelWait first if the waiting
// goroutine must be stopped.
func (p *PendingFlow) Clear() {
p.mu.Lock()
defer p.mu.Unlock()
p.flow = nil
p.info = AuthFlowInfo{}
p.expiresAt = time.Time{}
p.waitCancel = nil
}

View File

@@ -0,0 +1,74 @@
package sessionwatch
import (
"strconv"
"time"
)
// internal event kinds are no longer exposed: the watcher drives the Sink
// directly (NotifyStateChange on deadline change/clear, PublishEvent at
// each warning lead). Tests use a mock Sink to observe what the watcher
// emits.
// Metadata keys attached by the daemon to session-warning SystemEvents.
// The UI tray reads these to build a locale-aware notification without
// relying on the daemon's locale-less UserMessage string, and to
// disambiguate the T-WarningLead notification from the T-FinalWarningLead
// fallback that auto-opens the SessionAboutToExpire dialog.
const (
// MetaSessionWarning is set to "true" on both warning events (T-10 and
// T-2) so the UI can detect a session-warning SystemEvent without
// matching on the message text. Use MetaSessionFinal to distinguish
// the two.
MetaSessionWarning = "session_warning"
// MetaSessionFinal is set to "true" on the T-FinalWarningLead event
// only. Consumers that need to auto-open the SessionAboutToExpire
// dialog gate on this; T-WarningLead events leave the field unset.
MetaSessionFinal = "session_final_warning"
// MetaSessionExpiresAt carries the absolute UTC deadline encoded with
// FormatExpiresAt; consumers must decode with ParseExpiresAt so a
// future format change stays a single edit.
MetaSessionExpiresAt = "session_expires_at"
// MetaSessionLeadMinutes carries the lead in whole minutes (WarningLead
// for the T-10 event, FinalWarningLead for the T-2 event) so the UI
// can show "expires in ~N minutes" without hardcoding either constant.
MetaSessionLeadMinutes = "lead_minutes"
)
// expiresAtLayout is the wire format used for MetaSessionExpiresAt.
// Producer and consumers both go through FormatExpiresAt/ParseExpiresAt
// so this layout stays a single source of truth.
const expiresAtLayout = time.RFC3339
// FormatExpiresAt encodes a deadline for MetaSessionExpiresAt. Always
// emits UTC so a consumer in another timezone reads the same wall-clock
// deadline.
func FormatExpiresAt(t time.Time) string {
return t.UTC().Format(expiresAtLayout)
}
// ParseExpiresAt decodes the MetaSessionExpiresAt value back to a UTC
// time. Returns an error when the field is empty or malformed; the
// caller decides whether to fall back (zero value) or propagate.
func ParseExpiresAt(s string) (time.Time, error) {
t, err := time.Parse(expiresAtLayout, s)
if err != nil {
return time.Time{}, err
}
return t.UTC(), nil
}
// FormatLeadMinutes encodes a lead duration for MetaSessionLeadMinutes
// as the integer count of whole minutes. Sub-minute residuals are
// truncated — the field is informational ("expires in ~N minutes") and
// fractional minutes don't change what the UI displays.
func FormatLeadMinutes(d time.Duration) string {
return strconv.Itoa(int(d / time.Minute))
}
// ParseLeadMinutes decodes a MetaSessionLeadMinutes value. Returns 0
// and the parse error for malformed input; consumers that prefer a
// silent fallback can simply ignore the error.
func ParseLeadMinutes(s string) (int, error) {
return strconv.Atoi(s)
}

View File

@@ -0,0 +1,387 @@
// Package sessionwatch tracks the SSO session expiry deadline that the
// management server publishes via LoginResponse / SyncResponse and fires
// two warning events at fixed lead times before expiry: an interactive
// T-WarningLead notification and a dismiss-gated T-FinalWarningLead
// fallback dialog.
//
// The watcher is idempotent: Update may be called as often as the network
// map snapshots arrive. Repeating the same deadline is a no-op; a new
// deadline reschedules the timers and arms a fresh warning cycle.
//
// Warning firing is edge-detected. Each unique deadline value fires each
// warning callback at most once.
package sessionwatch
import (
"errors"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
cProto "github.com/netbirdio/netbird/client/proto"
)
const (
// Skew tolerates a small clock difference between the management
// server and this peer before treating a deadline as "in the past".
// Slightly above typical NTP drift; tight enough that the UI doesn't
// paint a stale expiry as if it were valid.
Skew = 30 * time.Second
// maxDeadlineHorizon caps how far in the future an accepted deadline
// can sit. A timestamp beyond this is almost certainly a protocol
// glitch, and silently arming a 100-year timer would hide the bug.
maxDeadlineHorizon = 10 * 365 * 24 * time.Hour
// WarningLead is how far before expiry the first (interactive)
// warning fires. Drives the T-10 OS notification with
// Extend/Dismiss actions.
WarningLead = 10 * time.Minute
// FinalWarningLead is how far before expiry the fallback final
// warning fires. Drives the auto-opened SessionAboutToExpire dialog,
// but only when the user has not dismissed the T-WarningLead warning
// for the same deadline. Must be strictly less than WarningLead.
FinalWarningLead = 2 * time.Minute
)
var (
// ErrDeadlineBeforeEpoch is returned by Update when the supplied
// deadline pre-dates 1970-01-01.
ErrDeadlineBeforeEpoch = errors.New("session deadline before unix epoch")
// ErrDeadlineTooFarFuture is returned by Update when the supplied
// deadline is more than maxDeadlineHorizon in the future.
ErrDeadlineTooFarFuture = errors.New("session deadline too far in the future")
// ErrDeadlineInPast is returned by Update when the supplied deadline
// is more than Skew in the past.
ErrDeadlineInPast = errors.New("session deadline in the past")
)
// StatusRecorder is the side-effect surface the watcher drives on every
// state transition. Production wires this to peer.Status (SetSessionExpiresAt
// for deadline change/clear, PublishEvent for the two warnings); tests pass
// a fake recorder so the same surface is observable without an engine.
//
// The watcher is the single owner of the deadline propagated to the
// recorder: every set, clear, sanity-check rejection and Close routes the
// value through SetSessionExpiresAt, so the SubscribeStatus snapshot the UI
// reads can never drift from the watcher's timer state. (SetSessionExpiresAt
// fans out its own state-change notification, so no separate notify is
// needed.) The recorder is server-scoped and outlives this engine-scoped
// watcher — without the Close-time clear a teardown (Down, or the Down+Up of
// a profile switch) would leave the next session showing the previous one's
// stale "expires in" value.
//
// PublishEvent's signature mirrors peer.Status.PublishEvent: the watcher
// composes the metadata internally so the wire format (MetaSession*) is
// owned by sessionwatch, not the caller.
type StatusRecorder interface {
SetSessionExpiresAt(deadline time.Time)
PublishEvent(
severity cProto.SystemEvent_Severity,
category cProto.SystemEvent_Category,
message string,
userMessage string,
metadata map[string]string,
)
}
// Watcher observes the latest session deadline and fires two warnings
// before it expires: the interactive T-WarningLead notification, and the
// fallback T-FinalWarningLead dialog (suppressed when the user dismissed
// the first one for the same deadline). Safe for concurrent use.
type Watcher struct {
lead time.Duration
finalLead time.Duration
mu sync.Mutex
current time.Time
timer *time.Timer
finalTimer *time.Timer
firedAt time.Time // deadline value the T-WarningLead callback last fired against
finalFiredAt time.Time // deadline value the T-FinalWarningLead callback last fired against
dismissedAt time.Time // deadline value the user dismissed via Dismiss(); gates fireFinal
closed bool
recorder StatusRecorder
}
// New returns a watcher with the package defaults WarningLead and
// FinalWarningLead. Pass nil for recorder to silence side effects (handy
// in unit tests that exercise sanity checks without observing the publish
// path).
func New(recorder StatusRecorder) *Watcher {
return NewWithLeads(WarningLead, FinalWarningLead, recorder)
}
// NewWithLeads returns a watcher with custom lead times. Useful for tests.
// final must be strictly less than lead; otherwise both timers fire in the
// wrong order or simultaneously and the UI flow breaks. A zero final lead
// disables the final-warning timer entirely (see armTimerLocked) so a
// millisecond-scale deadline doesn't flush both timers in one tick.
func NewWithLeads(lead, final time.Duration, recorder StatusRecorder) *Watcher {
return &Watcher{
lead: lead,
finalLead: final,
recorder: recorder,
}
}
// Update sets the latest deadline. Pass the zero time to clear (e.g. when
// a Sync push from the server omits the field because login expiration
// was disabled).
//
// Same-value updates are no-ops. A different non-zero value cancels any
// pending timer, resets the "already fired" guard, and arms a new one.
//
// Returns one of the sentinel Err* values when the deadline fails the
// sanity checks (pre-epoch, far future, or in the past beyond Skew).
// In every error case the watcher first clears its state so it stays
// consistent with what the caller will push into its other sinks (e.g.
// applySessionDeadline forces a zero deadline into the status recorder
// after a non-nil error).
func (w *Watcher) Update(deadline time.Time) error {
w.mu.Lock()
if w.closed {
w.mu.Unlock()
return nil
}
if deadline.IsZero() {
w.clearLocked()
return nil
}
now := time.Now()
switch {
case deadline.Before(time.Unix(0, 0)):
w.clearLocked()
return fmt.Errorf("%w: %v", ErrDeadlineBeforeEpoch, deadline)
case deadline.After(now.Add(maxDeadlineHorizon)):
w.clearLocked()
return fmt.Errorf("%w: %v", ErrDeadlineTooFarFuture, deadline)
case deadline.Before(now.Add(-Skew)):
w.clearLocked()
return fmt.Errorf("%w: %v (now=%v)", ErrDeadlineInPast, deadline, now)
}
if deadline.Equal(w.current) {
w.mu.Unlock()
return nil
}
w.stopTimerLocked()
w.current = deadline
// Reset every per-deadline guard so a refreshed deadline arms a fresh
// warning cycle: both edge triggers and the user Dismiss decision
// (the user agreed to the old deadline expiring; a new deadline
// restarts the contract).
w.firedAt = time.Time{}
w.finalFiredAt = time.Time{}
w.dismissedAt = time.Time{}
w.armTimerLocked(deadline)
recorder := w.recorder
w.mu.Unlock()
if recorder != nil {
recorder.SetSessionExpiresAt(deadline)
}
log.Infof("auth session deadline set to: %s (in %s)", deadline.Format(time.RFC3339), time.Until(deadline).Round(time.Second))
return nil
}
// Deadline returns the most recently observed deadline. Zero when no
// deadline is currently tracked.
func (w *Watcher) Deadline() time.Time {
w.mu.Lock()
defer w.mu.Unlock()
return w.current
}
// Dismiss records the user's "Dismiss" action against the current deadline
// and suppresses the upcoming final-warning callback for that deadline.
// Idempotent: repeated calls are no-ops. A subsequent Update with a fresh
// deadline resets the dismissal so the final-warning cycle re-arms.
//
// No-op when the watcher holds no deadline or has been closed.
func (w *Watcher) Dismiss() {
w.mu.Lock()
defer w.mu.Unlock()
if w.closed || w.current.IsZero() {
return
}
if w.dismissedAt.Equal(w.current) {
return
}
w.dismissedAt = w.current
// Cancel the armed final-warning timer eagerly. fireFinal would also
// gate on dismissedAt, but stopping the timer avoids a wakeup with
// nothing to do and makes the intent visible.
if w.finalTimer != nil {
w.finalTimer.Stop()
w.finalTimer = nil
}
log.Infof("auth session final-warning dismissed for deadline %s", w.current.Format(time.RFC3339))
}
// Close stops any pending timer and drops the deadline on the status
// recorder. Update calls after Close are ignored. Clearing the recorder
// here is what keeps a teardown (Down, or the Down+Up of a profile switch)
// from leaving the next session showing this one's stale "expires in"
// value — the recorder is server-scoped and outlives this engine-scoped
// watcher, so nothing else drops the anchor on teardown.
func (w *Watcher) Close() {
w.mu.Lock()
if w.closed {
w.mu.Unlock()
return
}
w.closed = true
w.stopTimerLocked()
hadDeadline := !w.current.IsZero()
w.current = time.Time{}
w.firedAt = time.Time{}
w.finalFiredAt = time.Time{}
w.dismissedAt = time.Time{}
recorder := w.recorder
w.mu.Unlock()
if recorder != nil && hadDeadline {
recorder.SetSessionExpiresAt(time.Time{})
}
}
// clearLocked drops the tracked deadline and notifies the recorder so
// downstream consumers (SubscribeStatus stream, UI) drop their anchor.
// The caller must hold w.mu; this helper releases it before invoking
// the recorder.
func (w *Watcher) clearLocked() {
if w.current.IsZero() {
w.mu.Unlock()
return
}
w.stopTimerLocked()
w.current = time.Time{}
w.firedAt = time.Time{}
w.finalFiredAt = time.Time{}
w.dismissedAt = time.Time{}
recorder := w.recorder
w.mu.Unlock()
if recorder != nil {
recorder.SetSessionExpiresAt(time.Time{})
}
log.Infof("auth session deadline cleared")
}
func (w *Watcher) stopTimerLocked() {
if w.timer != nil {
w.timer.Stop()
w.timer = nil
}
if w.finalTimer != nil {
w.finalTimer.Stop()
w.finalTimer = nil
}
}
func (w *Watcher) armTimerLocked(deadline time.Time) {
w.timer = armOneShotLocked(deadline.Add(-w.lead), func() { w.fire(deadline) })
// finalLead <= 0 disables the final-warning timer entirely. Used by
// tests that predate the final-warning fallback so a millisecond-scale
// deadline does not flush both timers at once.
if w.finalLead > 0 {
w.finalTimer = armOneShotLocked(deadline.Add(-w.finalLead), func() { w.fireFinal(deadline) })
}
}
func (w *Watcher) fire(armedFor time.Time) {
w.mu.Lock()
if w.closed || !w.current.Equal(armedFor) {
// Deadline moved while we were waiting (e.g. a successful extend).
// The reschedule path armed a fresh timer; this one is stale.
w.mu.Unlock()
return
}
if !w.firedAt.IsZero() && w.firedAt.Equal(armedFor) {
w.mu.Unlock()
return
}
w.firedAt = armedFor
recorder := w.recorder
w.mu.Unlock()
if recorder == nil {
return
}
log.Infof("auth session expiry soon warning fired")
publishWarning(recorder, armedFor, false)
}
// fireFinal mirrors fire for the T-FinalWarningLead timer with an extra
// dismiss-gate: if the user dismissed the T-WarningLead notification for
// this deadline, the final warning is suppressed entirely.
func (w *Watcher) fireFinal(armedFor time.Time) {
w.mu.Lock()
if w.closed || !w.current.Equal(armedFor) {
w.mu.Unlock()
return
}
if !w.finalFiredAt.IsZero() && w.finalFiredAt.Equal(armedFor) {
w.mu.Unlock()
return
}
if w.dismissedAt.Equal(armedFor) {
w.mu.Unlock()
log.Infof("auth session final-warning skipped (dismissed by user)")
return
}
w.finalFiredAt = armedFor
recorder := w.recorder
w.mu.Unlock()
if recorder == nil {
return
}
log.Infof("auth session final-warning fired")
publishWarning(recorder, armedFor, true)
}
// armOneShotLocked schedules cb at fireAt. When fireAt is already in the
// past it dispatches on the next scheduler tick so a state-change recorder
// notification (invoked after w.mu is released) lands first. Caller must
// hold w.mu.
func armOneShotLocked(fireAt time.Time, cb func()) *time.Timer {
delay := time.Until(fireAt)
if delay <= 0 {
return time.AfterFunc(0, cb)
}
return time.AfterFunc(delay, cb)
}
// publishWarning composes the SystemEvent for a watcher-fired warning and
// pushes it through the recorder. Severity is CRITICAL on both — bypassing
// the user's Notifications toggle is deliberate: missing the warning
// window forces the post-mortem SessionExpired flow (tunnel torn down,
// lock icon, manual re-login), which is the UX we are trying to avoid.
func publishWarning(recorder StatusRecorder, deadline time.Time, final bool) {
lead := WarningLead
message := "session expiry warning"
meta := map[string]string{
MetaSessionWarning: "true",
MetaSessionExpiresAt: FormatExpiresAt(deadline),
}
if final {
lead = FinalWarningLead
message = "session expiry final warning"
meta[MetaSessionFinal] = "true"
}
meta[MetaSessionLeadMinutes] = FormatLeadMinutes(lead)
recorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_AUTHENTICATION,
message,
"",
meta,
)
}

View File

@@ -0,0 +1,519 @@
package sessionwatch
import (
"errors"
"sync"
"testing"
"time"
cProto "github.com/netbirdio/netbird/client/proto"
)
// fakeRecorder satisfies StatusRecorder and records every call so tests
// can observe what the watcher emits. SetSessionExpiresAt and PublishEvent
// land in the same ordered events slice (with the Kind distinguishing
// them) so tests that care about ordering still work. lastDeadline holds
// the most recent value passed to SetSessionExpiresAt so tests can assert
// the recorder ended up cleared/set as expected.
type fakeRecorder struct {
mu sync.Mutex
events []event
lastDeadline time.Time
}
type eventKind int
const (
stateChange eventKind = iota
publish
)
type event struct {
kind eventKind
// Set only for publish events.
severity cProto.SystemEvent_Severity
category cProto.SystemEvent_Category
message string
meta map[string]string
}
// SetSessionExpiresAt mirrors peer.Status: a same-value write is a no-op,
// a real change records the new value and fans out a state-change (the
// production recorder calls notifyStateChange internally). The baseline
// is the zero time, so an initial clear before any deadline is set emits
// nothing — matching the real recorder.
func (r *fakeRecorder) SetSessionExpiresAt(deadline time.Time) {
r.mu.Lock()
defer r.mu.Unlock()
if r.lastDeadline.Equal(deadline) {
return
}
r.lastDeadline = deadline
r.events = append(r.events, event{kind: stateChange})
}
func (r *fakeRecorder) deadline() time.Time {
r.mu.Lock()
defer r.mu.Unlock()
return r.lastDeadline
}
func (r *fakeRecorder) PublishEvent(
severity cProto.SystemEvent_Severity,
category cProto.SystemEvent_Category,
message string,
_ string,
metadata map[string]string,
) {
r.mu.Lock()
defer r.mu.Unlock()
r.events = append(r.events, event{
kind: publish,
severity: severity,
category: category,
message: message,
meta: metadata,
})
}
func (r *fakeRecorder) snapshot() []event {
r.mu.Lock()
defer r.mu.Unlock()
out := make([]event, len(r.events))
copy(out, r.events)
return out
}
func (e event) isFinalWarning() bool {
return e.kind == publish && e.meta[MetaSessionFinal] == "true"
}
func (e event) isWarning() bool {
return e.kind == publish && e.meta[MetaSessionWarning] == "true" && e.meta[MetaSessionFinal] != "true"
}
func countWhere(events []event, pred func(event) bool) int {
n := 0
for _, e := range events {
if pred(e) {
n++
}
}
return n
}
func waitForEvents(t *testing.T, r *fakeRecorder, want int) []event {
t.Helper()
deadline := time.Now().Add(500 * time.Millisecond)
for time.Now().Before(deadline) {
if got := r.snapshot(); len(got) >= want {
return got
}
time.Sleep(5 * time.Millisecond)
}
got := r.snapshot()
t.Fatalf("timed out waiting for %d events, got %d: %+v", want, len(got), got)
return nil
}
// newWatcher builds a watcher with the final timer disabled (finalLead=0),
// matching the lead-only behaviour the pre-final-warning tests assume.
func newWatcher(lead time.Duration, r *fakeRecorder) *Watcher {
return NewWithLeads(lead, 0, r)
}
func TestUpdateZeroBeforeAnythingIsNoop(t *testing.T) {
r := &fakeRecorder{}
w := newWatcher(50*time.Millisecond, r)
defer w.Close()
_ = w.Update(time.Time{})
if got := r.snapshot(); len(got) != 0 {
t.Fatalf("expected no events on initial zero, got %+v", got)
}
}
func TestUpdateNonZeroFiresStateChange(t *testing.T) {
r := &fakeRecorder{}
w := newWatcher(50*time.Millisecond, r)
defer w.Close()
d := time.Now().Add(time.Hour)
_ = w.Update(d)
events := waitForEvents(t, r, 1)
if events[0].kind != stateChange {
t.Fatalf("expected stateChange, got %+v", events[0])
}
if !w.Deadline().Equal(d) {
t.Fatalf("deadline mismatch: %v vs %v", w.Deadline(), d)
}
}
func TestSameDeadlineIsNoop(t *testing.T) {
r := &fakeRecorder{}
w := newWatcher(50*time.Millisecond, r)
defer w.Close()
d := time.Now().Add(time.Hour)
_ = w.Update(d)
_ = w.Update(d)
_ = w.Update(d)
events := waitForEvents(t, r, 1)
if len(events) != 1 {
t.Fatalf("expected exactly 1 event for repeated same deadline, got %d: %+v", len(events), events)
}
}
func TestWarningFiresOnceWithinLeadWindow(t *testing.T) {
r := &fakeRecorder{}
lead := 50 * time.Millisecond
w := newWatcher(lead, r)
defer w.Close()
// Deadline 80ms out — warning should fire after ~30ms.
d := time.Now().Add(80 * time.Millisecond)
_ = w.Update(d)
events := waitForEvents(t, r, 2)
if events[0].kind != stateChange {
t.Fatalf("event[0] should be stateChange, got %+v", events[0])
}
if !events[1].isWarning() {
t.Fatalf("event[1] should be a warning publish, got %+v", events[1])
}
}
func TestWarningFiresImmediatelyWhenAlreadyInsideWindow(t *testing.T) {
r := &fakeRecorder{}
w := newWatcher(time.Hour, r) // lead > delta => fire immediately
defer w.Close()
d := time.Now().Add(10 * time.Millisecond)
_ = w.Update(d)
events := waitForEvents(t, r, 2)
if !events[1].isWarning() {
t.Fatalf("expected immediate warning publish, got %+v", events[1])
}
}
func TestNewDeadlineCancelsPriorTimer(t *testing.T) {
r := &fakeRecorder{}
lead := 50 * time.Millisecond
w := newWatcher(lead, r)
defer w.Close()
first := time.Now().Add(80 * time.Millisecond) // would fire warning ~30ms in
_ = w.Update(first)
// Replace with a far-future deadline before the warning fires.
time.Sleep(5 * time.Millisecond)
second := time.Now().Add(time.Hour)
_ = w.Update(second)
// Wait past when first's warning would have fired.
time.Sleep(80 * time.Millisecond)
if n := countWhere(r.snapshot(), event.isWarning); n != 0 {
t.Fatalf("warning fired for cancelled deadline: %+v", r.snapshot())
}
}
func TestRefreshAfterFireArmsNewWarning(t *testing.T) {
r := &fakeRecorder{}
lead := 30 * time.Millisecond
w := newWatcher(lead, r)
defer w.Close()
first := time.Now().Add(50 * time.Millisecond)
_ = w.Update(first)
// Wait for stateChange + warning of the first cycle.
waitForEvents(t, r, 2)
// Simulate a successful extend: brand new deadline.
second := time.Now().Add(60 * time.Millisecond)
_ = w.Update(second)
// 4 events total: stateChange, warning (first), stateChange, warning (second).
events := waitForEvents(t, r, 4)
if events[2].kind != stateChange {
t.Fatalf("event[2] should be stateChange for the new deadline, got %+v", events[2])
}
if !events[3].isWarning() {
t.Fatalf("event[3] should be a warning publish for the new deadline, got %+v", events[3])
}
}
func TestUpdateZeroAfterNonZeroClearsState(t *testing.T) {
r := &fakeRecorder{}
w := newWatcher(time.Hour, r)
defer w.Close()
d := time.Now().Add(2 * time.Hour)
_ = w.Update(d)
waitForEvents(t, r, 1)
_ = w.Update(time.Time{})
events := waitForEvents(t, r, 2)
if events[1].kind != stateChange {
t.Fatalf("expected stateChange on clear, got %+v", events[1])
}
if !w.Deadline().IsZero() {
t.Fatalf("Deadline should be zero after clear")
}
}
func TestUpdateRejectsBeforeEpoch(t *testing.T) {
r := &fakeRecorder{}
w := newWatcher(50*time.Millisecond, r)
defer w.Close()
good := time.Now().Add(time.Hour)
if err := w.Update(good); err != nil {
t.Fatalf("seed Update: %v", err)
}
err := w.Update(time.Unix(-100, 0))
if !errors.Is(err, ErrDeadlineBeforeEpoch) {
t.Fatalf("want ErrDeadlineBeforeEpoch, got %v", err)
}
if !w.Deadline().IsZero() {
t.Fatalf("rejected pre-epoch update must clear deadline; got %v", w.Deadline())
}
}
func TestUpdateRejectsTooFarFuture(t *testing.T) {
r := &fakeRecorder{}
w := newWatcher(50*time.Millisecond, r)
defer w.Close()
good := time.Now().Add(time.Hour)
if err := w.Update(good); err != nil {
t.Fatalf("seed Update: %v", err)
}
err := w.Update(time.Now().Add(50 * 365 * 24 * time.Hour))
if !errors.Is(err, ErrDeadlineTooFarFuture) {
t.Fatalf("want ErrDeadlineTooFarFuture, got %v", err)
}
if !w.Deadline().IsZero() {
t.Fatalf("rejected far-future update must clear deadline; got %v", w.Deadline())
}
}
func TestUpdateInPastClearsDeadline(t *testing.T) {
r := &fakeRecorder{}
w := newWatcher(50*time.Millisecond, r)
defer w.Close()
good := time.Now().Add(time.Hour)
if err := w.Update(good); err != nil {
t.Fatalf("seed Update: %v", err)
}
// Drain the stateChange from the seed.
waitForEvents(t, r, 1)
err := w.Update(time.Now().Add(-1 * time.Hour))
if !errors.Is(err, ErrDeadlineInPast) {
t.Fatalf("want ErrDeadlineInPast, got %v", err)
}
if !w.Deadline().IsZero() {
t.Fatalf("in-past update must clear the deadline, got %v", w.Deadline())
}
events := waitForEvents(t, r, 2)
if events[1].kind != stateChange {
t.Fatalf("expected stateChange on clear, got %+v", events[1])
}
}
func TestUpdateWithinSkewAccepted(t *testing.T) {
r := &fakeRecorder{}
w := newWatcher(50*time.Millisecond, r)
defer w.Close()
// 5 seconds in the past is within the 30s Skew tolerance — accept it.
d := time.Now().Add(-5 * time.Second)
if err := w.Update(d); err != nil {
t.Fatalf("within-skew Update should succeed, got %v", err)
}
if !w.Deadline().Equal(d) {
t.Fatalf("expected deadline to be applied, got %v want %v", w.Deadline(), d)
}
}
func TestCloseSilencesUpdates(t *testing.T) {
r := &fakeRecorder{}
w := newWatcher(50*time.Millisecond, r)
w.Close()
_ = w.Update(time.Now().Add(time.Hour))
time.Sleep(20 * time.Millisecond)
if got := r.snapshot(); len(got) != 0 {
t.Fatalf("expected no events after Close, got %+v", got)
}
}
// TestCloseClearsRecorderDeadline pins the profile-switch fix: a watcher
// holding a live deadline must zero the recorder on Close so the next
// engine's watcher (and the UI reading the shared server-scoped recorder)
// doesn't start out showing the previous session's stale "expires in".
func TestCloseClearsRecorderDeadline(t *testing.T) {
r := &fakeRecorder{}
w := newWatcher(time.Hour, r)
d := time.Now().Add(2 * time.Hour)
if err := w.Update(d); err != nil {
t.Fatalf("seed Update: %v", err)
}
if got := r.deadline(); !got.Equal(d) {
t.Fatalf("recorder deadline after Update = %v, want %v", got, d)
}
w.Close()
if got := r.deadline(); !got.IsZero() {
t.Fatalf("recorder deadline after Close = %v, want zero", got)
}
}
// TestCloseWithoutDeadlineLeavesRecorderUntouched guards the symmetric
// case: closing a watcher that never held a deadline must not emit a
// redundant clear (the recorder may legitimately hold a value written by
// some other path; the watcher only owns what it set).
func TestCloseWithoutDeadlineLeavesRecorderUntouched(t *testing.T) {
r := &fakeRecorder{}
w := newWatcher(time.Hour, r)
w.Close()
if got := r.snapshot(); len(got) != 0 {
t.Fatalf("expected no events from Close on an empty watcher, got %+v", got)
}
}
func TestFinalWarningFiresAfterRegularWarning(t *testing.T) {
r := &fakeRecorder{}
// Warning fires at deadline-80ms, final at deadline-30ms.
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
defer w.Close()
d := time.Now().Add(100 * time.Millisecond)
_ = w.Update(d)
// Expect stateChange + warning + final-warning.
events := waitForEvents(t, r, 3)
if countWhere(events, func(e event) bool { return e.kind == stateChange }) != 1 {
t.Fatalf("expected exactly 1 stateChange, got %+v", events)
}
if countWhere(events, event.isWarning) != 1 {
t.Fatalf("expected exactly 1 warning publish, got %+v", events)
}
if countWhere(events, event.isFinalWarning) != 1 {
t.Fatalf("expected exactly 1 final-warning publish, got %+v", events)
}
// Warning must precede final (same deadline, longer lead fires first).
var wIdx, fIdx int
for i, e := range events {
switch {
case e.isWarning():
wIdx = i
case e.isFinalWarning():
fIdx = i
}
}
if wIdx > fIdx {
t.Fatalf("warning must publish before final-warning, got order %+v", events)
}
}
func TestDismissSuppressesFinalWarning(t *testing.T) {
r := &fakeRecorder{}
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
defer w.Close()
d := time.Now().Add(100 * time.Millisecond)
_ = w.Update(d)
// Wait for the warning publish so we know we're inside the warning
// window, then dismiss before the final timer would fire.
deadline := time.Now().Add(500 * time.Millisecond)
for time.Now().Before(deadline) {
if countWhere(r.snapshot(), event.isWarning) >= 1 {
break
}
time.Sleep(2 * time.Millisecond)
}
if countWhere(r.snapshot(), event.isWarning) < 1 {
t.Fatalf("warning did not publish in time, events=%+v", r.snapshot())
}
w.Dismiss()
// Now wait past when the final would have fired.
time.Sleep(120 * time.Millisecond)
if n := countWhere(r.snapshot(), event.isFinalWarning); n != 0 {
t.Fatalf("final-warning published after Dismiss(), events=%+v", r.snapshot())
}
}
func TestDismissResetByNewDeadline(t *testing.T) {
r := &fakeRecorder{}
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
defer w.Close()
first := time.Now().Add(100 * time.Millisecond)
_ = w.Update(first)
// Dismiss against the first deadline.
w.Dismiss()
// Replace with a fresh deadline before the first's timers complete.
time.Sleep(10 * time.Millisecond)
second := time.Now().Add(100 * time.Millisecond)
_ = w.Update(second)
// The second cycle must publish a final-warning (the dismiss state
// did not carry over).
deadline := time.Now().Add(500 * time.Millisecond)
for time.Now().Before(deadline) {
if countWhere(r.snapshot(), event.isFinalWarning) >= 1 {
break
}
time.Sleep(5 * time.Millisecond)
}
if countWhere(r.snapshot(), event.isFinalWarning) < 1 {
t.Fatalf("final-warning did not publish on fresh deadline after Dismiss reset, events=%+v", r.snapshot())
}
}
func TestDismissBeforeUpdateIsNoop(t *testing.T) {
r := &fakeRecorder{}
w := NewWithLeads(80*time.Millisecond, 30*time.Millisecond, r)
defer w.Close()
// No deadline tracked yet; Dismiss must be a no-op (no panic, no state).
w.Dismiss()
d := time.Now().Add(100 * time.Millisecond)
_ = w.Update(d)
// Final warning should still publish — Dismiss only acts on the current
// deadline, and there was none at the time of the call.
deadline := time.Now().Add(500 * time.Millisecond)
for time.Now().Before(deadline) {
if countWhere(r.snapshot(), event.isFinalWarning) >= 1 {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("final-warning did not publish after no-op pre-Update Dismiss, events=%+v", r.snapshot())
}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"net"
"net/netip"
"path/filepath"
"runtime"
"runtime/debug"
"strings"
@@ -257,6 +256,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
// On daemon shutdown / Down() the parent context is cancelled
// and the dial fails with "context canceled". Wrapping that
// into state would leave the snapshot stuck at Connecting+err
// until the backoff loop wakes up — instead let the operation
// return cleanly so the deferred state.Set(StatusIdle) takes
// effect on the next iteration.
if c.ctx.Err() != nil {
return nil
}
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
}
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
@@ -347,11 +355,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
engineConfig.TempDir = mobileDependency.TempDir
// Leave StateDir empty when there is no state path so a disk-backed
// syncstore falls back to os.TempDir() instead of filepath.Dir("") == ".".
if path != "" {
engineConfig.StateDir = filepath.Dir(path)
}
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager)
@@ -390,6 +393,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
// Seed the session-expiry deadline from the LoginResponse. Subsequent
// changes flow in through SyncResponse and are applied in handleSync.
engine.ApplySessionDeadline(loginResp.GetSessionExpiresAt())
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
@@ -430,7 +437,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
c.statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff)
// Wrap the backoff with c.ctx so Down()/actCancel propagates into the
// inter-attempt sleep — otherwise a 15s MaxInterval can keep the retry
// loop alive long after the caller asked to give up, leaving the
// status stream stuck at Connecting.
err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {

View File

@@ -900,7 +900,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err
}
pf, err := uspfilter.Create(uspfilter.Config{IFace: wgIface, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err

View File

@@ -3,6 +3,7 @@ package dnsfwd
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"strconv"
@@ -159,13 +160,12 @@ func (m *Manager) allowDNSFirewall() error {
return nil
}
anyV4 := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
dnsRule, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept)
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
if err != nil {
return fmt.Errorf("add udp firewall rule: %w", err)
}
tcpRule, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept)
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
if err != nil {
return fmt.Errorf("add tcp firewall rule: %w", err)
}
@@ -174,12 +174,8 @@ func (m *Manager) allowDNSFirewall() error {
return fmt.Errorf("flush: %w", err)
}
if dnsRule != nil {
m.fwRules = []firewall.Rule{dnsRule}
}
if tcpRule != nil {
m.tcpRules = []firewall.Rule{tcpRule}
}
m.fwRules = dnsRules
m.tcpRules = tcpRules
m.registerNetstackServices()
@@ -213,12 +209,12 @@ func (m *Manager) unregisterNetstackServices() {
func (m *Manager) dropDNSFirewall() error {
var mErr *multierror.Error
for _, rule := range m.fwRules {
if err := m.firewall.DeleteFilterRule(rule); err != nil {
if err := m.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}
for _, rule := range m.tcpRules {
if err := m.firewall.DeleteFilterRule(rule); err != nil {
if err := m.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}

View File

@@ -22,6 +22,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/proto"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall"
@@ -55,7 +56,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/syncstore"
"github.com/netbirdio/netbird/client/internal/updater"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto"
@@ -148,10 +148,6 @@ type EngineConfig struct {
LogPath string
TempDir string
// StateDir is the directory holding the state file. The sync response
// (network map) is serialized here on platforms that persist it to disk.
StateDir string
}
// EngineServices holds the external service dependencies required by the Engine.
@@ -230,15 +226,10 @@ type Engine struct {
afpacketCapture *capture.AFPacketCapture
// Sync response persistence (protected by syncRespMux).
// syncStore is nil unless persistence has been enabled; its presence is
// what marks persistence as active. The backend (disk or memory) is
// selected per-platform; see the syncstore package. syncStoreDir is where
// a disk-backed store serializes to.
syncRespMux sync.RWMutex
syncStore syncstore.Store
syncStoreDir string
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse
flowManager nftypes.FlowManager
// auto-update
@@ -259,6 +250,20 @@ type Engine struct {
jobExecutorWG sync.WaitGroup
exposeManager *expose.Manager
sessionWatcher sessionDeadlineWatcher
}
// sessionDeadlineWatcher is the engine-facing surface of the SSO session
// expiry watcher. The concrete implementation (sessionwatch.Watcher) is wired
// in via newSessionWatcher, which is build-tagged so the js/wasm build links a
// no-op stub instead of pulling the full sessionwatch package (and its timer
// machinery) into the binary — the wasm client never runs the engine's
// session-warning flow.
type sessionDeadlineWatcher interface {
Update(deadline time.Time) error
Dismiss()
Close()
}
// Peer is an instance of the Connection Peer
@@ -301,8 +306,18 @@ func NewEngine(
jobExecutor: jobexec.NewExecutor(),
clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
syncStoreDir: config.StateDir,
}
// sessionWatcher keeps the SubscribeStatus consumers in sync with the
// session expiry deadline. Deadline-change ticks come for free via
// Status.SetSessionExpiresAt; the watcher exists to push a wake-up at
// T-WarningLead and T-FinalWarningLead so the UI repaints the remaining
// time / warning state even when nothing else changed, and to publish
// two SystemEvents (the warning composition lives in sessionwatch so
// the wire format stays owned by one package):
// - T-WarningLead → interactive "Extend now / Dismiss" notification
// - T-FinalWarningLead → auto-opened SessionAboutToExpire dialog,
// suppressed when the user dismissed the earlier warning
engine.sessionWatcher = newSessionWatcher(engine.statusRecorder)
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine
@@ -343,6 +358,10 @@ func (e *Engine) Stop() error {
e.srWatcher.Close()
}
if e.sessionWatcher != nil {
e.sessionWatcher.Close()
}
if e.updateManager != nil {
e.updateManager.SetDownloadOnly()
}
@@ -650,14 +669,14 @@ func (e *Engine) initFirewall() error {
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
// IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
if _, err := e.firewall.AddFilterRule(
if _, err := e.firewall.AddPeerFiltering(
nil,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
firewallManager.Network{},
net.IP{0, 0, 0, 0},
firewallManager.ProtocolUDP,
nil,
&port,
firewallManager.ActionAccept,
"",
); err != nil {
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
return nil
@@ -707,7 +726,7 @@ func (e *Engine) blockLanAccess() {
if network.Addr().Is6() {
source = v6
}
if _, err := e.firewall.AddFilterRule(
if _, err := e.firewall.AddRouteFiltering(
nil,
[]netip.Prefix{source},
firewallManager.Network{Prefix: network},
@@ -875,6 +894,8 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return e.ctx.Err()
}
e.ApplySessionDeadline(update.GetSessionExpiresAt())
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
@@ -923,19 +944,20 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// A non-nil syncStore is what marks persistence as enabled. Hold the lock for
// the whole Set so the store cannot be cleared (disabled / engine close)
// mid-call and have this write resurrect a file that was just removed.
// Read the storage-enabled flag under the syncRespMux too.
e.syncRespMux.RLock()
if e.syncStore != nil {
if err := e.syncStore.Set(update); err != nil {
log.Errorf("failed to persist sync response: %v", err)
} else {
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
}
}
enabled := e.persistSyncResponse
e.syncRespMux.RUnlock()
// Store sync response if persistence is enabled
if enabled {
e.syncRespMux.Lock()
e.latestSyncResponse = update
e.syncRespMux.Unlock()
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
}
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
return err
@@ -1151,7 +1173,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics,
RefreshStatus: func() {
e.RunHealthProbes(true)
e.RunHealthProbes(e.ctx, true)
},
}
@@ -1822,18 +1844,6 @@ func (e *Engine) close() {
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
}
// Drop any persisted sync response so its network map does not linger on
// disk after the engine stops (and cannot leak into a later run).
e.syncRespMux.Lock()
store := e.syncStore
e.syncStore = nil
e.syncRespMux.Unlock()
if store != nil {
if err := store.Clear(); err != nil {
log.Warnf("failed to clear persisted sync response on close: %v", err)
}
}
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
@@ -2048,7 +2058,20 @@ func (e *Engine) getRosenpassAddr() string {
// RunHealthProbes executes health checks for Signal, Management, Relay, and WireGuard services
// and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
//
// ctx scopes the (potentially slow) STUN/TURN probing: a caller that gives up —
// e.g. a Status RPC whose client disconnected — cancels its ctx and the probe
// returns instead of running to its per-component timeout. The engine's own
// lifetime ctx still applies independently, so an engine shutdown aborts the
// probe even if the caller's ctx is context.Background().
func (e *Engine) RunHealthProbes(ctx context.Context, waitForResult bool) bool {
// Tie the caller's ctx to the engine lifetime: either cancelling aborts
// the probe below.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
stop := context.AfterFunc(e.ctx, cancel)
defer stop()
e.syncMsgMux.Lock()
signalHealthy := e.signal.IsHealthy()
@@ -2071,9 +2094,9 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
if runtime.GOOS != "js" {
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
results = e.probeStunTurn.ProbeAllWaitResult(ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
results = e.probeStunTurn.ProbeAll(ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)
@@ -2163,42 +2186,45 @@ func (e *Engine) stopDNSServer() {
e.statusRecorder.UpdateDNSStates(nsGroupStates)
}
// SetSyncResponsePersistence enables or disables sync response persistence.
// The store is only instantiated while persistence is enabled; construction
// itself drops any stale data left over from an earlier run (see syncstore).
// SetSyncResponsePersistence enables or disables sync response persistence
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
e.syncRespMux.Lock()
defer e.syncRespMux.Unlock()
if enabled == (e.syncStore != nil) {
if enabled == e.persistSyncResponse {
return
}
e.persistSyncResponse = enabled
log.Debugf("Sync response persistence is set to %t", enabled)
if !enabled {
if err := e.syncStore.Clear(); err != nil {
log.Warnf("failed to clear persisted sync response: %v", err)
}
e.syncStore = nil
return
e.latestSyncResponse = nil
}
e.syncStore = syncstore.New(e.syncStoreDir)
}
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
// Hold the lock for the whole Get so the store cannot be cleared
// (disabled / engine close) mid-call.
e.syncRespMux.RLock()
defer e.syncRespMux.RUnlock()
enabled := e.persistSyncResponse
latest := e.latestSyncResponse
e.syncRespMux.RUnlock()
if e.syncStore == nil {
if !enabled {
return nil, errors.New("sync response persistence is disabled")
}
//nolint:nilnil
return e.syncStore.Get()
if latest == nil {
//nolint:nilnil
return nil, nil
}
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
if !ok {
return nil, fmt.Errorf("failed to clone sync response")
}
return sr, nil
}
// GetWgAddr returns the wireguard address
@@ -2234,7 +2260,7 @@ func (e *Engine) updateDNSForwarder(
enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry,
) {
if e.config.DisableServerRoutes || e.config.BlockInbound {
if e.config.DisableServerRoutes {
return
}
@@ -2387,7 +2413,7 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewal
var merr *multierror.Error
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
for _, rule := range rules {
proto, err := acl.ConvertToFirewallProtocol(rule.GetProtocol())
proto, err := convertToFirewallProtocol(rule.GetProtocol())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to convert protocol '%s': %w", rule.GetProtocol(), err))
continue

View File

@@ -0,0 +1,99 @@
package internal
import (
"context"
"errors"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/system"
)
// ApplySessionDeadline propagates the absolute SSO session deadline carried on
// LoginResponse / SyncResponse to both the watcher (for the edge-triggered
// warning) and the status recorder (for the SubscribeStatus / Status RPC
// snapshot the UI consumes).
//
// The wire field is 3-state:
// - nil → snapshot carries no info; keep the
// previously-anchored deadline (no-op)
// - explicit zero (s=0, n=0) → peer is not SSO-registered or expiry is
// disabled; clear both sinks
// - valid timestamp → new deadline; arm watcher, expose on
// status recorder
//
// Deadline sanity-checks live in sessionwatch.Watcher.Update. Any rejected
// value is treated as a clear on both sinks: the alternative — leaving the
// previously-known deadline in place — risks the UI confidently displaying
// a stale "expires in X" while the server has actually invalidated it.
func (e *Engine) ApplySessionDeadline(ts *timestamppb.Timestamp) {
if ts == nil {
return
}
var deadline time.Time
// Explicit zero (seconds=0 AND nanos=0) is the sentinel for "disabled".
// Everything else flows through Watcher.Update, whose sanity-checks
// reject out-of-range / pre-epoch / far-future / too-stale values and
// clear on rejection.
if ts.GetSeconds() != 0 || ts.GetNanos() != 0 {
deadline = ts.AsTime().UTC()
}
if e.sessionWatcher == nil {
return
}
// Watcher.Update owns the propagation to the status recorder (the
// SubscribeStatus / Status snapshot the UI reads): a set writes the
// deadline, a clear or a sanity-check rejection writes the zero value.
// Keeping a single writer is what stops the recorder from drifting out
// of sync with the warning timers.
if err := e.sessionWatcher.Update(deadline); err != nil {
log.Errorf("auth session deadline rejected: %v, clearing", err)
}
}
// DismissSessionWarning records the user's "Dismiss" click on the
// T-WarningLead interactive notification and suppresses the upcoming
// T-FinalWarningLead fallback for the current deadline. No-op when the
// watcher is not running or holds no deadline.
func (e *Engine) DismissSessionWarning() {
if e.sessionWatcher == nil {
return
}
e.sessionWatcher.Dismiss()
}
// ExtendAuthSession asks the management server to refresh the SSO session
// expiry deadline using the supplied JWT, then mirrors the new deadline into
// the daemon's state. The tunnel is untouched; no resync, no reconnect.
//
// Returns the new absolute UTC deadline (or zero time when the server
// reports the peer is not eligible for extension).
func (e *Engine) ExtendAuthSession(ctx context.Context, jwtToken string) (time.Time, error) {
if jwtToken == "" {
return time.Time{}, errors.New("jwt token is required")
}
if e.mgmClient == nil {
return time.Time{}, errors.New("management client is not initialised")
}
info, err := system.GetInfoWithChecks(ctx, e.checks)
if err != nil {
log.Warnf("failed to collect system info for session extend: %v", err)
info = system.GetInfo(ctx)
}
resp, err := e.mgmClient.ExtendAuthSession(info, jwtToken)
if err != nil {
return time.Time{}, fmt.Errorf("extend auth session on management: %w", err)
}
e.ApplySessionDeadline(resp.GetSessionExpiresAt())
if resp.GetSessionExpiresAt().IsValid() {
return resp.GetSessionExpiresAt().AsTime().UTC(), nil
}
return time.Time{}, nil
}

View File

@@ -0,0 +1,78 @@
package internal
import (
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/auth/sessionwatch"
"github.com/netbirdio/netbird/client/internal/peer"
)
// TestApplySessionDeadline_ThreeState pins down the 3-state semantics of the
// wire field carried on LoginResponse / SyncResponse:
//
// - nil pointer → no info; previously-anchored deadline survives
// - explicit zero value → "expiry disabled" sentinel; both sinks cleared
// - valid future timestamp → new deadline propagated to both sinks
func TestApplySessionDeadline_ThreeState(t *testing.T) {
newEngine := func() *Engine {
recorder := peer.NewRecorder("")
return &Engine{
statusRecorder: recorder,
sessionWatcher: sessionwatch.New(recorder),
}
}
t.Run("valid timestamp sets deadline on both sinks", func(t *testing.T) {
e := newEngine()
deadline := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
e.ApplySessionDeadline(timestamppb.New(deadline))
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(deadline),
"status recorder should hold the new deadline")
})
t.Run("nil is a no-op and preserves previous deadline", func(t *testing.T) {
e := newEngine()
seeded := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
e.ApplySessionDeadline(timestamppb.New(seeded))
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded))
e.ApplySessionDeadline(nil)
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded),
"nil snapshot must not disturb the existing deadline")
})
t.Run("explicit zero clears a previously-anchored deadline", func(t *testing.T) {
e := newEngine()
seeded := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
e.ApplySessionDeadline(timestamppb.New(seeded))
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded))
// Explicit zero Timestamp{} (seconds=0, nanos=0) is the
// "expiry disabled / not SSO" sentinel.
e.ApplySessionDeadline(&timestamppb.Timestamp{})
require.True(t, e.statusRecorder.GetSessionExpiresAt().IsZero(),
"explicit zero sentinel must clear the deadline")
})
t.Run("invalid timestamp clears the deadline", func(t *testing.T) {
e := newEngine()
seeded := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
e.ApplySessionDeadline(timestamppb.New(seeded))
require.True(t, e.statusRecorder.GetSessionExpiresAt().Equal(seeded))
// Out-of-range nanos → IsValid()==false; same-meaning as the
// disabled sentinel for downstream sinks.
e.ApplySessionDeadline(&timestamppb.Timestamp{Seconds: 1, Nanos: -1})
require.True(t, e.statusRecorder.GetSessionExpiresAt().IsZero(),
"invalid timestamp must clear the deadline")
})
}

View File

@@ -0,0 +1,16 @@
//go:build !js
package internal
import (
"github.com/netbirdio/netbird/client/internal/auth/sessionwatch"
"github.com/netbirdio/netbird/client/internal/peer"
)
// newSessionWatcher returns the real SSO session expiry watcher for every
// non-wasm build. The js/wasm build gets a no-op stub from
// engine_sessionwatch_js.go so the sessionwatch package (and its timer
// machinery) never links into the wasm binary.
func newSessionWatcher(recorder *peer.Status) sessionDeadlineWatcher {
return sessionwatch.New(recorder)
}

View File

@@ -0,0 +1,39 @@
//go:build js
package internal
import (
"time"
"github.com/netbirdio/netbird/client/internal/peer"
)
// noopSessionWatcher is the js/wasm stand-in for sessionwatch.Watcher. The
// wasm client never runs the engine's session-warning flow (the interactive
// T-WarningLead notification and the T-FinalWarningLead fallback dialog live
// in the desktop UI), so linking the full sessionwatch package (timers, event
// composition) would only bloat the binary.
//
// It still mirrors the deadline into the status recorder so the SubscribeStatus
// / Status snapshot the UI consumes stays correct — only the timer-driven
// warnings are dropped.
type noopSessionWatcher struct {
recorder *peer.Status
}
func newSessionWatcher(recorder *peer.Status) sessionDeadlineWatcher {
return noopSessionWatcher{recorder: recorder}
}
// Update mirrors the real watcher's recorder propagation without the timers or
// sanity-check sentinels: a valid deadline is exposed on the status snapshot,
// the zero time clears it.
func (w noopSessionWatcher) Update(deadline time.Time) error {
if w.recorder != nil {
w.recorder.SetSessionExpiresAt(deadline)
}
return nil
}
func (noopSessionWatcher) Dismiss() {}
func (noopSessionWatcher) Close() {}

View File

@@ -24,14 +24,14 @@ type RulePair struct {
type Manager struct {
dnatFirewall DNATFirewall
rules map[firewall.RuleID]RulePair
rules map[string]RulePair // keys is the ID of the ForwardRule
rulesMu sync.Mutex
}
func NewManager(dnatFirewall DNATFirewall) *Manager {
return &Manager{
dnatFirewall: dnatFirewall,
rules: make(map[firewall.RuleID]RulePair),
rules: make(map[string]RulePair),
}
}
@@ -41,7 +41,7 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
var mErr *multierror.Error
toDelete := make(map[firewall.RuleID]RulePair, len(h.rules))
toDelete := make(map[string]RulePair, len(h.rules))
for id, r := range h.rules {
toDelete[id] = r
}
@@ -59,10 +59,6 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': %v", fwdRule.String(), err))
continue
}
if rule == nil {
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': backend returned no rule", fwdRule.String()))
continue
}
log.Infof("forward rule has been added '%s'", fwdRule)
h.rules[id] = RulePair{
ForwardRule: fwdRule,
@@ -94,7 +90,7 @@ func (h *Manager) Close() error {
}
}
h.rules = make(map[firewall.RuleID]RulePair)
h.rules = make(map[string]RulePair)
return nberrors.FormatErrorOrNil(mErr)
}

View File

@@ -14,11 +14,11 @@ var (
)
type MocFwRule struct {
id firewall.RuleID
id string
}
func (m *MocFwRule) ID() firewall.RuleID {
return m.id
func (m *MocFwRule) ID() string {
return string(m.id)
}
type MockDNATFirewall struct {

View File

@@ -4,8 +4,6 @@ import (
"strings"
"github.com/hashicorp/go-version"
nbversion "github.com/netbirdio/netbird/version"
)
var (
@@ -13,7 +11,7 @@ var (
)
func IsSupported(agentVersion string) bool {
if nbversion.IsDevelopmentVersion(agentVersion) {
if agentVersion == "development" {
return true
}

Some files were not shown because too many files have changed in this diff Show More