Compare commits

...

49 Commits

Author SHA1 Message Date
Claude
b621a2628b [client] Use port 22338 for RDP sideband auth server
https://claude.ai/code/session_01C38bCDyYzLgxYLVwJkcUng
2026-04-11 18:21:22 +00:00
Claude
4949ca6194 [client] Add opt-in --allow-server-rdp flag, reuse SSH ACL, dynamic DLL registration
- Add ServerRDPAllowed field to daemon proto, EngineConfig, and profile config
- Add --allow-server-rdp flag to `netbird up` (opt-in, defaults to false)
- Wire RDP server start/stop in engine based on the flag
- Reuse SSH ACL (SSHAuth proto) for RDP authorization via sshauth.Authorizer
- Register/unregister credential provider COM DLL dynamically when flag is toggled
- Ship DLL alongside netbird.exe, register via regsvr32 at runtime (not install time)
- Update SetConfig tests to cover the new field

https://claude.ai/code/session_01C38bCDyYzLgxYLVwJkcUng
2026-04-11 17:48:53 +00:00
Claude
c5186f1483 [client] Add RDP token passthrough for passwordless Windows Remote Desktop
Implement sideband authorization and credential provider architecture for
passwordless RDP access to Windows peers via NetBird.

Go components:
- Sideband RDP auth server (TCP on WG interface, port 3390/22023)
- Pending session store with TTL expiry and replay protection
- Named pipe IPC server (\\.\pipe\netbird-rdp-auth) for credential provider
- Sideband client for connecting peer to request authorization
- CLI command `netbird rdp [user@]host` with JWT auth flow
- Engine integration with DNAT port redirection

Rust credential provider DLL (client/rdp/credprov/):
- COM DLL implementing ICredentialProvider + ICredentialProviderCredential
- Loaded by Windows LogonUI.exe at the RDP login screen
- Queries NetBird agent via named pipe for pending sessions
- Performs S4U logon (LsaLogonUser) for passwordless Windows token creation
- Self-registration via regsvr32 (DllRegisterServer/DllUnregisterServer)

https://claude.ai/code/session_01C38bCDyYzLgxYLVwJkcUng
2026-04-11 17:15:42 +00:00
Pascal Fischer
5259e5df51 [management] add domain and service cleanup migration (#5850) 2026-04-11 12:00:40 +02:00
Zoltan Papp
ebd78e0122 [client] Update RaceDial to accept context for improved cancellation handling (#5849) 2026-04-10 20:51:04 +02:00
Pascal Fischer
cf86b9a528 [management] enable access log cleanup by default (#5842) 2026-04-10 17:07:27 +02:00
Pascal Fischer
ee588e1536 Revert "[management] allow local routing peer resource (#5814)" (#5847) 2026-04-10 14:53:47 +02:00
Pascal Fischer
2a8aacc5c9 [management] allow local routing peer resource (#5814) 2026-04-10 13:08:21 +02:00
Pascal Fischer
15709bc666 [management] update account delete with proper proxy domain and service cleanup (#5817) 2026-04-10 13:08:04 +02:00
Pascal Fischer
789b4113fe [misc] update dashboards (#5840) 2026-04-10 12:15:58 +02:00
Viktor Liu
d2cdc0efec [client] Use native firewall for peer ACLs in userspace WireGuard mode (#5668) 2026-04-10 09:12:13 +08:00
Pascal Fischer
ee343d5d77 [management] use sql null vars (#5844) 2026-04-09 18:12:38 +02:00
Maycon Santos
099c493b18 [management] network map tests (#5795)
* Add network map benchmark and correctness test files

* Add tests for network map components correctness and edge cases

* Skip benchmarks in CI and enhance network map test coverage with new helper functions

* Remove legacy network map benchmarks and tests; refactor components-based test coverage for clarity and scalability.
2026-04-08 21:28:29 +02:00
Pascal Fischer
c1d1229ae0 [management] use NullBool for terminated flag (#5829) 2026-04-08 21:08:43 +02:00
Viktor Liu
94a36cb53e [client] Handle UPnP routers that only support permanent leases (#5826) 2026-04-08 17:59:59 +02:00
Viktor Liu
c7ba931466 [client] Populate network addresses in FreeBSD system info (#5827) 2026-04-08 17:14:16 +02:00
Viktor Liu
413d95b740 [client] Include service.json in debug bundle (#5825)
* Include service.json in debug bundle

* Add tests for service params sanitization logic
2026-04-08 21:10:31 +08:00
Viktor Liu
332c624c55 [client] Don't abort UI debug bundle when up/down fails (#5780) 2026-04-08 10:33:46 +02:00
Viktor Liu
dc160aff36 [client] Fix SSH proxy stripping shell quoting from forwarded commands (#5669) 2026-04-08 10:25:57 +02:00
Zoltan Papp
96806bf55f [relay] Replace net.Conn with context-aware Conn interface (#5770)
* [relay] Replace net.Conn with context-aware Conn interface for relay transports

Introduce a listener.Conn interface with context-based Read/Write methods,
replacing net.Conn throughout the relay server. This enables proper timeout
propagation (e.g. handshake timeout) without goroutine-based workarounds
and removes unused LocalAddr/SetDeadline methods from WS and QUIC conns.

* [relay] Refactor Peer context management to ensure proper cleanup

Integrate context creation (`context.WithCancel`) directly in `NewPeer` and remove redundant initialization in `Work`. Add `ctxCancel` calls to ensure context is properly canceled during `Close` operations.
2026-04-08 09:38:31 +02:00
Viktor Liu
d33cd4c95b [client] Add NAT-PMP/UPnP support (#5202) 2026-04-08 15:29:32 +08:00
Maycon Santos
e2c2f64be7 [client] Fix iOS DNS upstream routing for deselected exit nodes (#5803)
- Add GetSelectedClientRoutes() to the route manager that filters through FilterSelectedExitNodes, returning only active routes instead of all management routes              
  - Use GetSelectedClientRoutes() in the DNS route checker so deselected exit nodes' 0.0.0.0/0 no longer matches upstream DNS IPs — this prevented the resolver from switching
  away from the utun-bound socket after exit node deselection                                                                                                                   
  - Initialize iOS DNS server with host DNS fallback addresses (1.1.1.1:53, 1.0.0.1:53) and a permanent root zone handler, matching Android's behavior — without this, unmatched
   DNS queries arriving via the 0.0.0.0/0 tunnel route had no handler and were silently dropped
2026-04-08 08:43:48 +02:00
Viktor Liu
cb73b94ffb [client] Add TCP DNS support for local listener (#5758) 2026-04-08 07:40:36 +02:00
Viktor Liu
1d920d700c [client] Fix SSH server Stop() deadlock when sessions are active (#5717) 2026-04-07 17:56:54 +02:00
Viktor Liu
bb85eee40a [client] Skip down interfaces in network address collection for posture checks (#5768) 2026-04-07 17:56:48 +02:00
Viktor Liu
aba5d6f0d2 [client] Error out on netbird expose when block inbound is enabled (#5818) 2026-04-07 17:55:35 +02:00
Viktor Liu
0588d2dbe1 [management] Load missing service columns in pgx account loader (#5816) 2026-04-07 14:56:56 +02:00
Pascal Fischer
14b3b77bda [management] validate permissions on groups read with name (#5749) 2026-04-07 14:13:09 +02:00
Zoltan Papp
6da34e483c [client] Fix mgmProber interface to match unexported GetServerPublicKey (#5815)
Update the mgmProber interface to use HealthCheck() instead of the
now-unexported GetServerPublicKey(), aligning with the changes in the
management client API.
2026-04-07 13:13:38 +02:00
Zoltan Papp
0efef671d7 [client] Unexport GetServerPublicKey, add HealthCheck method (#5735)
* Unexport GetServerPublicKey, add HealthCheck method

Internalize server key fetching into Login, Register,
GetDeviceAuthorizationFlow, and GetPKCEAuthorizationFlow methods,
removing the need for callers to fetch and pass the key separately.

Replace the exported GetServerPublicKey with a HealthCheck() error
method for connection validation, keeping IsHealthy() bool for
non-blocking background monitoring.

Fix test encryption to use correct key pairs (client public key as
remotePubKey instead of server private key).

* Refactor `doMgmLogin` to return only error, removing unused response
2026-04-07 12:18:21 +02:00
Eduard Gert
435203b13b [proxy] Update proxy web packages (#5661)
* [proxy] Update package-lock.json

* Update packages
2026-04-07 10:35:09 +02:00
Maycon Santos
decb5dd3af [client] Add GetSelectedClientRoutes to route manager and update DNS route check (#5802)
- DNS resolution broke after deselecting an exit node because the route checker used all client routes (including deselected ones) to decide how to forward upstream DNS
  queries
  - Added GetSelectedClientRoutes() to the route manager that filters out deselected exit nodes, and switched the DNS route checker to use it
  - Confirmed fix via device testing: after deselecting exit node, DNS queries now correctly use a regular network socket instead of binding to the utun interface
2026-04-05 13:44:53 +02:00
Viktor Liu
28fbf96b2a [client] Fix flaky TestServiceLifecycle/Restart on FreeBSD (#5786) 2026-04-02 21:45:49 +02:00
Bethuel Mmbaga
9d1a37c644 [management,client] Revert gRPC client secret removal (#5781)
* This reverts commit e5914e4e8b

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Deprecate client secret in proto

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-04-02 18:21:00 +02:00
Viktor Liu
5bf2372c4d [management] Fix L4 service creation deadlock on single-connection databases (#5779) 2026-04-02 14:46:14 +02:00
Bethuel Mmbaga
c2c6396a04 [management] Allow updating embedded IdP user name and email (#5721) 2026-04-02 13:02:10 +03:00
Misha Bragin
aaf813fc0c Add selfhosted scaling note (#5769) 2026-04-01 19:23:39 +02:00
Vlad
d97fe84296 [management] fix race condition in the setup flow that enables creation of multiple owner users (#5754) 2026-04-01 16:25:35 +02:00
tham-le
81f45dab21 [client] Support embed.Client on Android with netstack mode (#5623)
* [client] Support embed.Client on Android with netstack mode

embed.Client.Start() calls ConnectClient.Run() which passes an empty
MobileDependency{}. On Android, the engine dereferences nil fields
(IFaceDiscover, NetworkChangeListener, DnsReadyListener) causing panics.

Provide complete no-op stubs so the engine's existing Android code
paths work unchanged — zero modifications to engine.go:

- Add androidRunOverride hook in Run() for Android-specific dispatch
- Add runOnAndroidEmbed() with complete MobileDependency (all stubs)
- Wire default stubs via init() in connect_android_default.go:
  noopIFaceDiscover, noopNetworkChangeListener, noopDnsReadyListener
- Forward logPath to c.run()

Tested: embed.Client starts on Android arm64, joins mesh via relay,
discovers peers, localhost proxy works for TCP+UDP forwarding.

* [client] Fix TestServiceParamsPath for Windows path separators

Use filepath.Join in test assertions instead of hardcoded POSIX paths
so the test passes on Windows where filepath.Join uses backslashes.
2026-04-01 16:19:34 +02:00
Zoltan Papp
d670e7382a [client] Fix ipv6 address in quic server (#5763)
* [client] Use `net.JoinHostPort` for consistency in constructing host-port pairs

* [client] Fix handling of IPv6 addresses by trimming brackets in `net.JoinHostPort`
2026-04-01 15:11:23 +02:00
Pascal Fischer
cd8c686339 [misc] add path traversal and file size protections (#5755) 2026-04-01 14:23:24 +02:00
Pascal Fischer
f5c41e3018 [misc] set permissions on env file for getting started scripts (#5761) 2026-04-01 14:13:53 +02:00
Pascal Fischer
2477f99d89 [proxy] Add pprof (#5764) 2026-04-01 14:10:41 +02:00
shuuri-labs
940f530ac2 [management] Legacy to embedded IdP migration tool (#5586) 2026-04-01 13:53:19 +02:00
Zoltan Papp
4d3e2f8ad3 Fix path join (#5762) 2026-04-01 13:21:19 +02:00
Vlad
5ae986e1c4 [management] fix panic on management reboot (#5759) 2026-04-01 12:31:30 +02:00
Bethuel Mmbaga
e5914e4e8b [management,client] Remove client secret from gRPC auth flow (#5751)
Remove client secret from gRPC auth flow. The secret was originally included to support providers like Google Workspace that don't offer a proper PKCE flow, but this is no longer necessary with the embedded IdP. Deployments using such providers should migrate to the embedded IdP instead.
2026-03-31 18:50:49 +03:00
Pascal Fischer
c238f5425f [management] proper module permission validation for posture check delete (#5742) 2026-03-31 16:43:49 +02:00
Pascal Fischer
3c3097ea74 [management] add target user account validation (#5741) 2026-03-31 16:43:16 +02:00
188 changed files with 14512 additions and 2057 deletions

View File

@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -154,6 +154,26 @@ builds:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-idp-migrate
dir: tools/idp-migrate
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-idp-migrate
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
@@ -166,6 +186,10 @@ archives:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
- id: netbird-idp-migrate
builds:
- netbird-idp-migrate
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
nfpms:
- maintainer: Netbird <dev@netbird.io>

View File

@@ -199,9 +199,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.")
}
needsRestoreUp := false
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
} else {
needsRestoreUp = !stateWasDown
cmd.Println("netbird down")
}
@@ -217,6 +219,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
needsRestoreUp = false
cmd.Println("netbird up")
}
@@ -264,6 +267,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
if needsRestoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up (restored)")
}
}
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())

View File

@@ -14,6 +14,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/proto"
@@ -201,7 +202,7 @@ func exposeFn(cmd *cobra.Command, args []string) error {
stream, err := client.ExposeService(ctx, req)
if err != nil {
return fmt.Errorf("expose service: %w", err)
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
}
if err := handleExposeReady(cmd, stream, port); err != nil {
@@ -236,7 +237,7 @@ func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %w", err)
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
}
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)

276
client/cmd/rdp.go Normal file
View File

@@ -0,0 +1,276 @@
package cmd
import (
"context"
"errors"
"fmt"
"net"
"os"
"os/signal"
"os/user"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
rdpclient "github.com/netbirdio/netbird/client/rdp/client"
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util"
)
const (
serverRDPAllowedFlag = "allow-server-rdp"
)
var (
rdpUsername string
rdpHost string
rdpNoBrowser bool
rdpNoCache bool
serverRDPAllowed bool
)
func init() {
rdpCmd.PersistentFlags().StringVarP(&rdpUsername, "user", "u", "", "Windows username on remote peer")
rdpCmd.PersistentFlags().BoolVar(&rdpNoBrowser, noBrowserFlag, false, noBrowserDesc)
rdpCmd.PersistentFlags().BoolVar(&rdpNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
upCmd.PersistentFlags().BoolVar(&serverRDPAllowed, serverRDPAllowedFlag, false, "Allow RDP passthrough on peer (passwordless RDP via credential provider)")
}
var rdpCmd = &cobra.Command{
Use: "rdp [flags] [user@]host",
Short: "Connect to a NetBird peer via RDP (passwordless)",
Long: `Connect to a NetBird peer using Remote Desktop Protocol with token-based
passwordless authentication. The target peer must have RDP passthrough enabled.
This command:
1. Obtains a JWT token via OIDC authentication
2. Sends the token to the target peer's sideband auth service
3. If authorized, launches mstsc.exe to connect
Examples:
netbird rdp peer-hostname
netbird rdp administrator@peer-hostname
netbird rdp --user admin peer-hostname`,
Args: cobra.MinimumNArgs(1),
RunE: rdpFn,
}
func rdpFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
// Parse user@host
if err := parseRDPHostArg(args[0]); err != nil {
return err
}
ctx := internal.CtxInitState(cmd.Context())
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
rdpCtx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
if err := runRDP(rdpCtx, cmd); err != nil {
errCh <- err
}
cancel()
}()
select {
case <-sig:
cancel()
<-rdpCtx.Done()
return nil
case err := <-errCh:
return err
case <-rdpCtx.Done():
}
return nil
}
func parseRDPHostArg(arg string) error {
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return errors.New("invalid user@host format")
}
if rdpUsername == "" {
rdpUsername = parts[0]
}
rdpHost = parts[1]
} else {
rdpHost = arg
}
if rdpUsername == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
rdpUsername = sudoUser
} else if currentUser, err := user.Current(); err == nil {
rdpUsername = currentUser.Username
} else {
rdpUsername = "Administrator"
}
}
return nil
}
func runRDP(ctx context.Context, cmd *cobra.Command) error {
// Connect to daemon
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() { _ = grpcConn.Close() }()
daemonClient := proto.NewDaemonServiceClient(grpcConn)
// Resolve peer IP
peerIP, err := resolvePeerIP(ctx, daemonClient, rdpHost)
if err != nil {
return fmt.Errorf("resolve peer %s: %w", rdpHost, err)
}
cmd.Printf("Connecting to %s@%s (%s)...\n", rdpUsername, rdpHost, peerIP)
// Obtain JWT token
hint := profilemanager.GetLoginHint()
var browserOpener func(string) error
if !rdpNoBrowser {
browserOpener = util.OpenBrowser
}
jwtToken, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !rdpNoCache, hint, browserOpener)
if err != nil {
return fmt.Errorf("JWT authentication: %w", err)
}
log.Debug("JWT authentication successful")
cmd.Println("Authenticated. Requesting RDP access...")
// Generate nonce for replay protection
nonce, err := rdpserver.GenerateNonce()
if err != nil {
return fmt.Errorf("generate nonce: %w", err)
}
// Send sideband auth request
authClient := rdpclient.New()
authAddr := net.JoinHostPort(peerIP, fmt.Sprintf("%d", rdpserver.DefaultRDPAuthPort))
resp, err := authClient.RequestAuth(ctx, authAddr, &rdpserver.AuthRequest{
JWTToken: jwtToken,
RequestedUser: rdpUsername,
ClientPeerIP: "", // will be filled by the server from the connection
Nonce: nonce,
})
if err != nil {
cmd.Printf("Failed to authorize RDP session with %s\n", rdpHost)
cmd.Printf("\nTroubleshooting:\n")
cmd.Printf(" 1. Check connectivity: netbird status -d\n")
cmd.Printf(" 2. Verify RDP passthrough is enabled on the target peer\n")
return fmt.Errorf("sideband auth: %w", err)
}
if resp.Status != rdpserver.StatusAuthorized {
return fmt.Errorf("RDP access denied: %s", resp.Reason)
}
cmd.Printf("RDP access authorized (session: %s, user: %s)\n", resp.SessionID, resp.OSUser)
cmd.Printf("Launching Remote Desktop client...\n")
// Launch mstsc.exe (platform-specific)
if err := launchRDPClient(peerIP); err != nil {
return fmt.Errorf("launch RDP client: %w", err)
}
return nil
}
// resolvePeerIP resolves a peer hostname/FQDN to its WireGuard IP address
// by querying the daemon for the current peer status.
func resolvePeerIP(ctx context.Context, client proto.DaemonServiceClient, peerAddress string) (string, error) {
statusResp, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return "", fmt.Errorf("get daemon status: %w", err)
}
if statusResp.GetFullStatus() == nil {
return "", errors.New("daemon returned empty status")
}
for _, peer := range statusResp.GetFullStatus().GetPeers() {
if matchesPeer(peer, peerAddress) {
ip := peer.GetIP()
if ip == "" {
continue
}
// Strip CIDR suffix if present
if idx := strings.Index(ip, "/"); idx != -1 {
ip = ip[:idx]
}
return ip, nil
}
}
// If not found as a peer name, try as a direct IP
if addr, err := net.ResolveIPAddr("ip", peerAddress); err == nil {
return addr.String(), nil
}
return "", fmt.Errorf("peer %q not found in network", peerAddress)
}
func matchesPeer(peer *proto.PeerState, address string) bool {
address = strings.ToLower(address)
if strings.EqualFold(peer.GetFqdn(), address) {
return true
}
// Match against FQDN without trailing dot
fqdn := strings.TrimSuffix(peer.GetFqdn(), ".")
if strings.EqualFold(fqdn, address) {
return true
}
// Match against short hostname (first part of FQDN)
if parts := strings.SplitN(fqdn, ".", 2); len(parts) > 0 {
if strings.EqualFold(parts[0], address) {
return true
}
}
// Match against IP
ip := peer.GetIP()
if idx := strings.Index(ip, "/"); idx != -1 {
ip = ip[:idx]
}
if ip == address {
return true
}
return false
}

13
client/cmd/rdp_stub.go Normal file
View File

@@ -0,0 +1,13 @@
//go:build !windows
package cmd
import "fmt"
// launchRDPClient is a stub for non-Windows platforms.
func launchRDPClient(peerIP string) error {
fmt.Printf("RDP session authorized for %s\n", peerIP)
fmt.Println("Note: mstsc.exe is only available on Windows.")
fmt.Printf("Use any RDP client to connect to %s:3389\n", peerIP)
return nil
}

34
client/cmd/rdp_windows.go Normal file
View File

@@ -0,0 +1,34 @@
//go:build windows
package cmd
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
)
// launchRDPClient launches the native Windows Remote Desktop client (mstsc.exe).
func launchRDPClient(peerIP string) error {
mstscPath, err := exec.LookPath("mstsc.exe")
if err != nil {
return fmt.Errorf("mstsc.exe not found: %w", err)
}
cmd := exec.Command(mstscPath, fmt.Sprintf("/v:%s", peerIP))
if err := cmd.Start(); err != nil {
return fmt.Errorf("start mstsc.exe: %w", err)
}
log.Debugf("launched mstsc.exe (PID %d) connecting to %s", cmd.Process.Pid, peerIP)
// Don't wait for mstsc to exit - it runs independently
go func() {
if err := cmd.Wait(); err != nil {
log.Debugf("mstsc.exe exited: %v", err)
}
}()
return nil
}

View File

@@ -150,6 +150,7 @@ func init() {
rootCmd.AddCommand(logoutCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(rdpCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)

View File

@@ -25,10 +25,10 @@ func TestServiceParamsPath(t *testing.T) {
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = "/var/lib/netbird"
assert.Equal(t, "/var/lib/netbird/service.json", serviceParamsPath())
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
configs.StateDir = "/custom/state"
assert.Equal(t, "/custom/state/service.json", serviceParamsPath())
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
}
func TestSaveAndLoadServiceParams(t *testing.T) {

View File

@@ -4,7 +4,9 @@ import (
"context"
"fmt"
"os"
"os/signal"
"runtime"
"syscall"
"testing"
"time"
@@ -13,6 +15,22 @@ import (
"github.com/stretchr/testify/require"
)
// TestMain intercepts when this test binary is run as a daemon subprocess.
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
// "service run ..." arguments. Since the test binary can't handle cobra CLI
// args, it exits immediately, causing daemon -r to respawn rapidly until
// hitting the rate limit and exiting. This makes service restart unreliable.
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
<-sig
return
}
os.Exit(m.Run())
}
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
@@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) {
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {

View File

@@ -356,6 +356,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverRDPAllowedFlag).Changed {
req.ServerRDPAllowed = &serverRDPAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
@@ -458,6 +461,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverRDPAllowedFlag).Changed {
ic.ServerRDPAllowed = &serverRDPAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
@@ -582,6 +588,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverRDPAllowedFlag).Changed {
loginRequest.ServerRDPAllowed = &serverRDPAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"os"
"strconv"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
@@ -35,20 +36,27 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
log.Info("forcing userspace firewall")
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
// Kernel cannot fall back to anything else, need to return error
if !iface.IsUserspaceBind() {
return fm, err
}
// Fall back to the userspace packet filter if native is unavailable
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
return fm, nil
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
@@ -160,3 +168,17 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
func forceUserspaceFirewall() bool {
val := os.Getenv(EnvForceUserspaceFirewall)
if val == "" {
return false
}
force, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
return false
}
return force
}

View File

@@ -7,6 +7,12 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when
// native iptables/nftables is available. This only applies when the WireGuard interface
// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of
// kernel netfilter rules.
const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL"
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string

View File

@@ -33,7 +33,6 @@ type Manager struct {
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
IsUserspaceBind() bool
}
// Create iptables firewall manager
@@ -64,10 +63,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
},
}
stateManager.RegisterState(state)
@@ -203,12 +201,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
return nberrors.FormatErrorOrNil(merr)
}
// AllowNetbird allows netbird interface traffic
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
}
_, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0},
@@ -286,6 +282,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOUTPUT = "OUTPUT"

View File

@@ -47,8 +47,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)

View File

@@ -36,6 +36,7 @@ const (
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
chainNATOutput = "NETBIRD-NAT-OUTPUT"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
@@ -43,6 +44,7 @@ const (
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
jumpNatOutput = "jump-nat-output"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
@@ -387,6 +389,14 @@ func (r *router) cleanUpDefaultForwardRules() error {
}
log.Debug("flushing routing related tables")
// Remove jump rules from built-in chains before deleting custom chains,
// otherwise the chain deletion fails with "device or resource busy".
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
log.Debugf("clean OUTPUT jump rule: %v", err)
}
for _, chainInfo := range []struct {
chain string
table string
@@ -396,6 +406,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainNATOutput, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
@@ -970,6 +981,81 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.rules[jumpNatOutput]; exists {
return nil
}
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
}
if !chainExists {
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
}
}
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
if !chainExists {
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
}
}
return fmt.Errorf("add OUTPUT jump rule: %w", err)
}
r.rules[jumpNatOutput] = jumpRule
r.updateState()
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
dnatRule := []string{
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
r.updateState()
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil

View File

@@ -9,10 +9,9 @@ import (
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -23,10 +22,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex

View File

@@ -169,6 +169,14 @@ type Manager interface {
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error

View File

@@ -40,7 +40,6 @@ func getTableName() string {
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
IsUserspaceBind() bool
}
// Manager of iptables firewall
@@ -106,10 +105,9 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -205,12 +203,10 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
return m.router.RemoveNatRule(pair)
}
// AllowNetbird allows netbird interface traffic
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
}
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -346,6 +342,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRawOutput = "netbird-raw-out"
chainNameRawPrerouting = "netbird-raw-pre"

View File

@@ -52,8 +52,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
// just check on the local interface

View File

@@ -36,6 +36,7 @@ const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameNATOutput = "netbird-nat-output"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
@@ -1853,6 +1854,130 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.chains[chainNameNATOutput]; exists {
return nil
}
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
Name: chainNameNATOutput,
Table: r.workTable,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
if err := r.conn.Flush(); err != nil {
delete(r.chains, chainNameNATOutput)
return fmt.Errorf("create NAT output chain: %w", err)
}
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameNATOutput],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,

View File

@@ -8,10 +8,9 @@ import (
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -22,10 +21,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}

View File

@@ -140,6 +140,17 @@ type Manager struct {
mtu uint16
mssClampValue uint16
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[packetHook]
tcpHookOut atomic.Pointer[packetHook]
}
// packetHook stores a registered hook for a specific IP:port.
type packetHook struct {
ip netip.Addr
port uint16
fn func([]byte) bool
}
// decoder for packages
@@ -594,6 +605,8 @@ func (m *Manager) resetState() {
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
m.tcpHookOut.Store(nil)
if m.udpTracker != nil {
m.udpTracker.Close()
@@ -713,6 +726,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
case layers.LayerTypeTCP:
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
return true
}
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
@@ -895,38 +911,21 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
d.dnatOrigPort = 0
}
// udpHooksDrop checks if any UDP hooks should drop the packet
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
}
// Check specific destination IP first
if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
// Check IPv4 unspecified address
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
if h.ip == dstIP && h.port == dport {
return h.fn(packetData)
}
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false
}
@@ -1278,12 +1277,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook
if rule.udpHook != nil {
return rule.mgmtId, rule.udpHook(packetData), true
}
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
@@ -1342,65 +1335,30 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return sourceMatched
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
r := PeerRule{
id: uuid.New().String(),
ip: ip,
protoLayer: layers.LayerTypeUDP,
dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6,
udpHook: hook,
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.udpHookOut.Store(nil)
return
}
if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
if in {
// Incoming UDP hooks are stored in allow rules map
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
}
m.incomingRules[r.ip][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip] = make(map[string]PeerRule)
}
m.outgoingRules[r.ip][r.id] = r
}
m.mutex.Unlock()
return r.id
m.udpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
}
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// Check incoming hooks (stored in allow rules)
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.tcpHookOut.Store(nil)
return
}
// Check outgoing hooks
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
}
return fmt.Errorf("hook with given id not found")
m.tcpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
}
// SetLogLevel sets the log level for the firewall manager

View File

@@ -12,6 +12,7 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
@@ -186,81 +187,52 @@ func TestManagerDeleteRule(t *testing.T) {
}
}
func TestAddUDPPacketHook(t *testing.T) {
tests := []struct {
name string
in bool
expDir fw.RuleDirection
ip netip.Addr
dPort uint16
hook func([]byte) bool
expectedID string
}{
{
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000,
hook: func([]byte) bool { return true },
},
{
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: netip.MustParseAddr("::1"),
dPort: 9000,
hook: func([]byte) bool { return false },
},
}
func TestSetUDPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
var called bool
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
called = true
return true
})
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
h := manager.udpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(8000), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
var addedRule PeerRule
if tt.in {
// Incoming UDP hooks are stored in allow rules map
if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
return
}
for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule
}
} else {
if len(manager.outgoingRules[tt.ip]) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
return
}
for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule
}
}
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
assert.Nil(t, manager.udpHookOut.Load())
}
if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
if tt.dPort != addedRule.dPort.Values[0] {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
return
}
if layers.LayerTypeUDP != addedRule.protoLayer {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return
}
if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set")
return
}
})
}
func TestSetTCPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
var called bool
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
called = true
return true
})
h := manager.tcpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(53), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
assert.Nil(t, manager.tcpHookOut.Load())
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
@@ -530,39 +502,12 @@ func TestRemovePacketHook(t *testing.T) {
require.NoError(t, manager.Close(nil))
}()
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
found = true
break
}
}
}
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
if !found {
t.Fatalf("The hook was not added properly.")
}
// Now remove the packet hook
err = manager.RemovePacketHook(hookID)
if err != nil {
t.Fatalf("Failed to remove hook: %s", err)
}
// Assert the hook is removed by checking it in the manager's outgoing rules
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
}
}
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
}
func TestProcessOutgoingHooks(t *testing.T) {
@@ -592,8 +537,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
}
hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
manager.SetUDPPacketHook(
netip.MustParseAddr("100.10.0.100"),
53,
func([]byte) bool {
@@ -601,7 +545,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
return true
},
)
require.NotEmpty(t, hookID)
// Create test UDP packet
ipv4 := &layers.IPv4{

View File

@@ -144,6 +144,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
// case where an interface comes up between refreshes.
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}

View File

@@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
// TODO: also delegate to nativeFirewall when available for kernel WG mode
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
@@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// AddOutputDNAT delegates to the native firewall if available.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return fmt.Errorf("output DNAT not supported without native firewall")
}
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT delegates to the native firewall if available.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {

View File

@@ -18,9 +18,7 @@ type PeerRule struct {
protoLayer gopacket.LayerType
sPort *firewall.Port
dPort *firewall.Port
drop bool
udpHook func([]byte) bool
drop bool
}
// ID returns the rule id

View File

@@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) {
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
return true // drop (intercepted by hook)
})
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageCompleted,
},
expectedAllow: false,

View File

@@ -15,14 +15,17 @@ type PacketFilter interface {
// FilterInbound filter incoming packets from external sources to host
FilterInbound(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
// Hook function returns true if the packet should be dropped.
// Only one UDP hook is supported; calling again replaces the previous hook.
// Pass nil hook to remove.
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
// Hook function returns true if the packet should be dropped.
// Only one TCP hook is supported; calling again replaces the previous hook.
// Pass nil hook to remove.
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
}
// FilteredDevice to override Read or Write of packets

View File

@@ -34,18 +34,28 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
// SetUDPPacketHook mocks base method.
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string)
return ret0
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
}
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
}
// SetTCPPacketHook mocks base method.
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
}
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
}
// FilterInbound mocks base method.
@@ -75,17 +85,3 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
}
// RemovePacketHook mocks base method.
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RemovePacketHook indicates an expected call of RemovePacketHook.
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
}

View File

@@ -1,87 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
import (
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockPacketFilter is a mock of PacketFilter interface.
type MockPacketFilter struct {
ctrl *gomock.Controller
recorder *MockPacketFilterMockRecorder
}
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
type MockPacketFilterMockRecorder struct {
mock *MockPacketFilter
}
// NewMockPacketFilter creates a new mock instance.
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
mock := &MockPacketFilter{ctrl: ctrl}
mock.recorder = &MockPacketFilterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
}
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterInbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
}
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -19,6 +19,9 @@ import (
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
@@ -135,6 +138,7 @@ func TestDefaultManager(t *testing.T) {
func TestDefaultManagerStateless(t *testing.T) {
// stateless currently only in userspace, so we have to disable kernel
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
t.Setenv("NB_DISABLE_CONNTRACK", "true")
networkMap := &mgmProto.NetworkMap{
@@ -194,6 +198,7 @@ func TestDefaultManagerStateless(t *testing.T) {
// This tests the full ACL manager -> uspfilter integration.
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
@@ -258,6 +263,7 @@ func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
// up when they're removed from the network map in a subsequent update.
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
@@ -339,6 +345,7 @@ func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
// one added without leaking.
func TestRuleUpdateChangingAction(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()

View File

@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
var needsLogin bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
err := a.doMgmLogin(client, ctx, pubSSHKey)
if isLoginNeeded(err) {
needsLogin = true
return nil
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
var isAuthError bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) {
err := a.doMgmLogin(client, ctx, pubSSHKey)
if isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil {
@@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
protoFlow, err := client.GetPKCEAuthorizationFlow()
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
@@ -221,7 +215,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
@@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
protoFlow, err := client.GetDeviceAuthorizationFlow()
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
@@ -266,7 +254,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
@@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
sysInfo := system.GetInfo(ctx)
a.setSystemInfoFlags(sysInfo)
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
return serverKey, loginResp, err
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
return err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
@@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
a.setSystemInfoFlags(info)
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err

View File

@@ -44,6 +44,10 @@ import (
"github.com/netbirdio/netbird/version"
)
// androidRunOverride is set on Android to inject mobile dependencies
// when using embed.Client (which calls Run() with empty MobileDependency).
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
@@ -76,6 +80,9 @@ func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
if androidRunOverride != nil {
return androidRunOverride(c, runningChan, logPath)
}
return c.run(MobileDependency{}, runningChan, logPath)
}
@@ -104,6 +111,7 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
dnsAddresses []netip.AddrPort,
stateFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -113,6 +121,7 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
HostDNSAddresses: dnsAddresses,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")
@@ -534,6 +543,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
ServerRDPAllowed: config.ServerRDPAllowed != nil && *config.ServerRDPAllowed,
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
@@ -610,12 +620,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
@@ -634,12 +638,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {
return nil, err
}
return loginResp, nil
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
}
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {

View File

@@ -0,0 +1,73 @@
//go:build android
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
// It returns an empty interface list, which means ICE P2P candidates won't be
// discovered — connections will fall back to relay. Applications that need P2P
// should provide a real implementation via runOnAndroidEmbed that uses
// Android's ConnectivityManager to enumerate network interfaces.
type noopIFaceDiscover struct{}
func (noopIFaceDiscover) IFaces() (string, error) {
// Return empty JSON array — no local interfaces advertised for ICE.
// This is intentional: without Android's ConnectivityManager, we cannot
// reliably enumerate interfaces (netlink is restricted on Android 11+).
// Relay connections still work; only P2P hole-punching is disabled.
return "[]", nil
}
// noopNetworkChangeListener is a stub for embed.Client on Android.
// Network change events are ignored since the embed client manages its own
// reconnection logic via the engine's built-in retry mechanism.
type noopNetworkChangeListener struct{}
func (noopNetworkChangeListener) OnNetworkChanged(string) {
// No-op: embed.Client relies on the engine's internal reconnection
// logic rather than OS-level network change notifications.
}
func (noopNetworkChangeListener) SetInterfaceIP(string) {
// No-op: in netstack mode, the overlay IP is managed by the userspace
// network stack, not by OS-level interface configuration.
}
// noopDnsReadyListener is a stub for embed.Client on Android.
// DNS readiness notifications are not needed in netstack/embed mode
// since system DNS is disabled and DNS resolution happens externally.
type noopDnsReadyListener struct{}
func (noopDnsReadyListener) OnReady() {
// No-op: embed.Client does not need DNS readiness notifications.
// System DNS is disabled in netstack mode.
}
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
var _ dns.ReadyListener = noopDnsReadyListener{}
func init() {
// Wire up the default override so embed.Client.Start() works on Android
// with netstack mode. Provides complete no-op stubs for all mobile
// dependencies so the engine's existing Android code paths work unchanged.
// Applications that need P2P ICE or real DNS should replace this by
// setting androidRunOverride before calling Start().
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
return c.runOnAndroidEmbed(
noopIFaceDiscover{},
noopNetworkChangeListener{},
[]netip.AddrPort{},
noopDnsReadyListener{},
runningChan,
logPath,
)
}
}

View File

@@ -0,0 +1,32 @@
//go:build android
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
// so embed.Client.Start() can detect when the engine is ready.
// It provides complete MobileDependency so the engine's existing
// Android code paths work unchanged.
func (c *ConnectClient) runOnAndroidEmbed(
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
runningChan chan struct{},
logPath string,
) error {
mobileDependency := MobileDependency{
IFaceDiscover: iFaceDiscover,
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return c.run(mobileDependency, runningChan, logPath)
}

View File

@@ -25,6 +25,7 @@ import (
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updater/installer"
@@ -52,6 +53,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states for the active profile.
service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
mutex.prof: Mutex profiling information.
goroutine.prof: Goroutine profiling information.
@@ -359,6 +361,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
}
if err := g.addServiceParams(); err != nil {
log.Errorf("failed to add service params to debug bundle: %v", err)
}
if err := g.addMetrics(); err != nil {
log.Errorf("failed to add metrics to debug bundle: %v", err)
}
@@ -488,6 +494,90 @@ func (g *BundleGenerator) addConfig() error {
return nil
}
const (
serviceParamsFile = "service.json"
serviceParamsBundle = "service_params.json"
maskedValue = "***"
envVarPrefix = "NB_"
jsonKeyManagementURL = "management_url"
jsonKeyServiceEnv = "service_env_vars"
)
var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
func (g *BundleGenerator) addServiceParams() error {
path := filepath.Join(configs.StateDir, serviceParamsFile)
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("read service params: %w", err)
}
var params map[string]any
if err := json.Unmarshal(data, &params); err != nil {
return fmt.Errorf("parse service params: %w", err)
}
if g.anonymize {
if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
}
}
g.sanitizeServiceEnvVars(params)
sanitizedData, err := json.MarshalIndent(params, "", " ")
if err != nil {
return fmt.Errorf("marshal sanitized service params: %w", err)
}
if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
return fmt.Errorf("add service params to zip: %w", err)
}
return nil
}
// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
if !ok {
return
}
sanitized := make(map[string]any, len(envVars))
for k, v := range envVars {
val, _ := v.(string)
switch {
case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
sanitized[k] = maskedValue
case g.anonymize:
sanitized[k] = g.anonymizer.AnonymizeString(val)
default:
sanitized[k] = val
}
}
params[jsonKeyServiceEnv] = sanitized
}
// isSensitiveEnvVar returns true for env var names that may contain secrets.
func isSensitiveEnvVar(key string) bool {
lower := strings.ToLower(key)
for _, s := range sensitiveEnvSubstrings {
if strings.Contains(lower, s) {
return true
}
}
return false
}
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n")

View File

@@ -1,8 +1,12 @@
package debug
import (
"archive/zip"
"bytes"
"encoding/json"
"net"
"os"
"path/filepath"
"strings"
"testing"
@@ -10,6 +14,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -420,6 +425,226 @@ func TestAnonymizeNetworkMap(t *testing.T) {
}
}
func TestIsSensitiveEnvVar(t *testing.T) {
tests := []struct {
key string
sensitive bool
}{
{"NB_SETUP_KEY", true},
{"NB_API_TOKEN", true},
{"NB_CLIENT_SECRET", true},
{"NB_PASSWORD", true},
{"NB_CREDENTIAL", true},
{"NB_LOG_LEVEL", false},
{"NB_MANAGEMENT_URL", false},
{"NB_HOSTNAME", false},
{"HOME", false},
{"PATH", false},
}
for _, tt := range tests {
t.Run(tt.key, func(t *testing.T) {
assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
})
}
}
func TestSanitizeServiceEnvVars(t *testing.T) {
tests := []struct {
name string
anonymize bool
input map[string]any
check func(t *testing.T, params map[string]any)
}{
{
name: "no env vars key",
anonymize: false,
input: map[string]any{"management_url": "https://mgmt.example.com"},
check: func(t *testing.T, params map[string]any) {
t.Helper()
assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
_, ok := params[jsonKeyServiceEnv]
assert.False(t, ok, "service_env_vars should not be added")
},
},
{
name: "non-NB vars are masked",
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"HOME": "/root",
"PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
},
},
{
name: "sensitive NB vars are masked",
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
},
},
{
name: "safe NB vars anonymized when anonymize is true",
anonymize: true,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
"NB_LOG_LEVEL": "debug",
"NB_SETUP_KEY": "secret",
"SOME_OTHER": "val",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
// Safe NB_ values should be anonymized (not the original, not masked)
mgmtVal := env["NB_MANAGEMENT_URL"].(string)
assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
logVal := env["NB_LOG_LEVEL"].(string)
assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
// Sensitive and non-NB_ still masked
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
assert.Equal(t, maskedValue, env["SOME_OTHER"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
g := &BundleGenerator{
anonymize: tt.anonymize,
anonymizer: anonymizer,
}
g.sanitizeServiceEnvVars(tt.input)
tt.check(t, tt.input)
})
}
}
func TestAddServiceParams(t *testing.T) {
t.Run("missing service.json returns nil", func(t *testing.T) {
g := &BundleGenerator{
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
}
origStateDir := configs.StateDir
configs.StateDir = t.TempDir()
t.Cleanup(func() { configs.StateDir = origStateDir })
err := g.addServiceParams()
assert.NoError(t, err)
})
t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
dir := t.TempDir()
origStateDir := configs.StateDir
configs.StateDir = dir
t.Cleanup(func() { configs.StateDir = origStateDir })
input := map[string]any{
jsonKeyManagementURL: "https://api.example.com:443",
jsonKeyServiceEnv: map[string]any{
"NB_LOG_LEVEL": "trace",
},
}
data, err := json.Marshal(input)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
g := &BundleGenerator{
anonymize: true,
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
archive: zw,
}
require.NoError(t, g.addServiceParams())
require.NoError(t, zw.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
require.Len(t, zr.File, 1)
assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
rc, err := zr.File[0].Open()
require.NoError(t, err)
defer rc.Close()
var result map[string]any
require.NoError(t, json.NewDecoder(rc).Decode(&result))
mgmt := result[jsonKeyManagementURL].(string)
assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
assert.NotEmpty(t, mgmt)
env := result[jsonKeyServiceEnv].(map[string]any)
assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
})
t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
dir := t.TempDir()
origStateDir := configs.StateDir
configs.StateDir = dir
t.Cleanup(func() { configs.StateDir = origStateDir })
input := map[string]any{
jsonKeyManagementURL: "https://api.example.com:443",
}
data, err := json.Marshal(input)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
g := &BundleGenerator{
anonymize: false,
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
archive: zw,
}
require.NoError(t, g.addServiceParams())
require.NoError(t, zw.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
rc, err := zr.File[0].Open()
require.NoError(t, err)
defer rc.Close()
var result map[string]any
require.NoError(t, json.NewDecoder(rc).Decode(&result))
assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
})
}
// Helper function to check if IP is in CGNAT range
func isInCGNATRange(ip net.IP) bool {
cgnat := net.IPNet{

View File

@@ -73,6 +73,9 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
return nil
}
w.response = m
if m.MsgHdr.Truncated {
w.SetMeta("truncated", "true")
}
return w.ResponseWriter.WriteMsg(m)
}
@@ -195,10 +198,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
startTime := time.Now()
requestID := resutil.GenerateRequestID()
logger := log.WithFields(log.Fields{
fields := log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
})
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
question := r.Question[0]
qname := strings.ToLower(question.Name)
@@ -261,9 +268,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
meta, time.Since(startTime))
cw.response.Len(), meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {

View File

@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
})
}
// TestLocalResolver_Stop tests cleanup on Stop
// TestLocalResolver_Stop tests cleanup on GracefullyStop
func TestLocalResolver_Stop(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
t.Run("GracefullyStop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
resolver.Stop()
})
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()
lookupStarted := make(chan struct{})

View File

@@ -90,6 +90,11 @@ func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
// Mock implementation - no-op
}
// SetFirewall mock implementation of SetFirewall from Server interface
func (m *MockServer) SetFirewall(Firewall) {
// Mock implementation - no-op
}
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op

View File

@@ -104,3 +104,23 @@ func (r *responseWriter) TsigTimersOnly(bool) {
// After a call to Hijack(), the DNS package will not do anything with the connection.
func (r *responseWriter) Hijack() {
}
// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging.
func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr {
var srcIP net.IP
if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil {
srcIP = ipv4.(*layers.IPv4).SrcIP
} else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil {
srcIP = ipv6.(*layers.IPv6).SrcIP
}
var srcPort int
if udp := packet.Layer(layers.LayerTypeUDP); udp != nil {
srcPort = int(udp.(*layers.UDP).SrcPort)
}
if srcIP == nil {
return nil
}
return &net.UDPAddr{IP: srcIP, Port: srcPort}
}

View File

@@ -58,6 +58,7 @@ type Server interface {
UpdateServerConfig(domains dnsconfig.ServerDomains) error
PopulateManagementDomain(mgmtURL *url.URL) error
SetRouteChecker(func(netip.Addr) bool)
SetFirewall(Firewall)
}
type nsGroupsByDomain struct {
@@ -151,7 +152,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
if config.WgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(config.WgInterface)
} else {
dnsService = newServiceViaListener(config.WgInterface, addrPort)
dnsService = newServiceViaListener(config.WgInterface, addrPort, nil)
}
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
@@ -186,11 +187,16 @@ func NewDefaultServerIos(
ctx context.Context,
wgInterface WGIface,
iosDnsManager IosDnsManager,
hostsDnsList []netip.AddrPort,
statusRecorder *peer.Status,
disableSys bool,
) *DefaultServer {
log.Debugf("iOS host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.iosDnsManager = iosDnsManager
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true
ds.addHostRootZone()
return ds
}
@@ -374,6 +380,17 @@ func (s *DefaultServer) DnsIP() netip.Addr {
return s.service.RuntimeIP()
}
// SetFirewall sets the firewall used for DNS port DNAT rules.
// This must be called before Initialize when using the listener-based service,
// because the firewall is typically not available at construction time.
func (s *DefaultServer) SetFirewall(fw Firewall) {
if svc, ok := s.service.(*serviceViaListener); ok {
svc.listenerFlagLock.Lock()
svc.firewall = fw
svc.listenerFlagLock.Unlock()
}
}
// Stop stops the server
func (s *DefaultServer) Stop() {
s.probeMu.Lock()
@@ -395,8 +412,12 @@ func (s *DefaultServer) Stop() {
maps.Clear(s.extraDomains)
}
func (s *DefaultServer) disableDNS() error {
defer s.service.Stop()
func (s *DefaultServer) disableDNS() (retErr error) {
defer func() {
if err := s.service.Stop(); err != nil {
retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err))
}
}()
if s.isUsingNoopHostManager() {
return nil

View File

@@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
@@ -1071,7 +1071,7 @@ func (m *mockHandler) ID() types.HandlerID { return types.Hand
type mockService struct{}
func (m *mockService) Listen() error { return nil }
func (m *mockService) Stop() {}
func (m *mockService) Stop() error { return nil }
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
func (m *mockService) RuntimePort() int { return 53 }
func (m *mockService) RegisterMux(string, dns.Handler) {}

View File

@@ -4,15 +4,25 @@ import (
"net/netip"
"github.com/miekg/dns"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
DefaultPort = 53
)
// Firewall provides DNAT capabilities for DNS port redirection.
// This is used when the DNS server cannot bind port 53 directly
// and needs firewall rules to redirect traffic.
type Firewall interface {
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
}
type service interface {
Listen() error
Stop()
Stop() error
RegisterMux(domain string, handler dns.Handler)
DeregisterMux(key string)
RuntimePort() int

View File

@@ -10,9 +10,13 @@ import (
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
@@ -31,25 +35,33 @@ type serviceViaListener struct {
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
tcpServer *dns.Server
listenIP netip.Addr
listenPort uint16
listenerIsRunning bool
listenerFlagLock sync.Mutex
ebpfService ebpfMgr.Manager
firewall Firewall
tcpDNATConfigured bool
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener {
mux := dns.NewServeMux()
s := &serviceViaListener{
wgInterface: wgIface,
dnsMux: mux,
customAddr: customAddr,
firewall: fw,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
tcpServer: &dns.Server{
Net: "tcp",
Handler: mux,
},
}
return s
@@ -70,43 +82,86 @@ func (s *serviceViaListener) Listen() error {
return fmt.Errorf("eval listen address: %w", err)
}
s.listenIP = s.listenIP.Unmap()
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
s.server.Addr = addr
s.tcpServer.Addr = addr
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
log.Debugf("starting dns on %s (UDP + TCP)", addr)
s.listenerIsRunning = true
go func() {
if err := s.server.ListenAndServe(); err != nil {
log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err)
}
s.listenerFlagLock.Lock()
unexpected := s.listenerIsRunning
s.listenerIsRunning = false
s.listenerFlagLock.Unlock()
if unexpected {
if err := s.tcpServer.Shutdown(); err != nil {
log.Debugf("failed to shutdown DNS TCP server: %v", err)
}
}
}()
go func() {
if err := s.tcpServer.ListenAndServe(); err != nil {
log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err)
}
}()
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
// a DNAT rule because eBPF only handles UDP.
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err)
} else {
s.tcpDNATConfigured = true
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
}
}
return nil
}
func (s *serviceViaListener) Stop() {
func (s *serviceViaListener) Stop() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
return nil
}
s.listenerIsRunning = false
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
var merr *multierror.Error
if err := s.server.ShutdownContext(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err))
}
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err))
}
if s.tcpDNATConfigured && s.firewall != nil {
if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
}
s.tcpDNATConfigured = false
}
if s.ebpfService != nil {
err = s.ebpfService.FreeDNSFwd()
if err != nil {
log.Errorf("stopping traffic forwarder returned an error: %v", err)
if err := s.ebpfService.FreeDNSFwd(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
@@ -133,12 +188,6 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
return s.listenIP
}
func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
s.listenerIsRunning = running
}
// evalListenAddress figure out the listen address for the DNS server
// first check the 53 port availability on WG interface or lo, if not success
@@ -187,18 +236,28 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
}
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
addrPort := netip.AddrPortFrom(ip, uint16(port))
udpAddr := net.UDPAddrFromAddrPort(addrPort)
udpLn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
return false
}
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
if err := udpLn.Close(); err != nil {
log.Debugf("close UDP probe listener: %s", err)
}
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
return false
}
if err := tcpLn.Close(); err != nil {
log.Debugf("close TCP probe listener: %s", err)
}
return true
}

View File

@@ -0,0 +1,86 @@
package dns
import (
"fmt"
"net"
"net/netip"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("192.0.2.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
// Create a service using a custom address to avoid needing root
svc := newServiceViaListener(nil, nil, nil)
svc.dnsMux.Handle(".", handler)
// Bind both transports up front to avoid TOCTOU races.
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
t.Skip("cannot bind to 127.0.0.153, skipping")
}
port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
udpConn.Close()
t.Skip("cannot bind TCP on same port, skipping")
}
addr := fmt.Sprintf("%s:%d", customIP, port)
svc.server.PacketConn = udpConn
svc.tcpServer.Listener = tcpLn
svc.listenIP = customIP
svc.listenPort = port
go func() {
if err := svc.server.ActivateAndServe(); err != nil {
t.Logf("udp server: %v", err)
}
}()
go func() {
if err := svc.tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
svc.listenerIsRunning = true
defer func() {
require.NoError(t, svc.Stop())
}()
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
// Test UDP query
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
udpResp, _, err := udpClient.Exchange(q, addr)
require.NoError(t, err, "UDP query should succeed")
require.NotNil(t, udpResp)
require.NotEmpty(t, udpResp.Answer)
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
// Test TCP query
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
tcpResp, _, err := tcpClient.Exchange(q, addr)
require.NoError(t, err, "TCP query should succeed")
require.NotNil(t, tcpResp)
require.NotEmpty(t, tcpResp.Answer)
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
}

View File

@@ -1,6 +1,7 @@
package dns
import (
"errors"
"fmt"
"net/netip"
"sync"
@@ -10,6 +11,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/client/net"
)
@@ -18,7 +20,8 @@ type ServiceViaMemory struct {
dnsMux *dns.ServeMux
runtimeIP netip.Addr
runtimePort int
udpFilterHookID string
tcpDNS *tcpDNSServer
tcpHookSet bool
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
@@ -28,14 +31,13 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
s := &ServiceViaMemory{
return &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: lastIP,
runtimePort: DefaultPort,
}
return s
}
func (s *ServiceViaMemory) Listen() error {
@@ -46,10 +48,8 @@ func (s *ServiceViaMemory) Listen() error {
return nil
}
var err error
s.udpFilterHookID, err = s.filterDNSTraffic()
if err != nil {
return fmt.Errorf("filter dns traffice: %w", err)
if err := s.filterDNSTraffic(); err != nil {
return fmt.Errorf("filter dns traffic: %w", err)
}
s.listenerIsRunning = true
@@ -57,19 +57,29 @@ func (s *ServiceViaMemory) Listen() error {
return nil
}
func (s *ServiceViaMemory) Stop() {
func (s *ServiceViaMemory) Stop() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
return nil
}
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
filter := s.wgInterface.GetFilter()
if filter != nil {
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
if s.tcpHookSet {
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
}
}
if s.tcpDNS != nil {
s.tcpDNS.Stop()
}
s.listenerIsRunning = false
return nil
}
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
@@ -88,10 +98,18 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
return s.runtimeIP
}
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
func (s *ServiceViaMemory) filterDNSTraffic() error {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
return errors.New("DNS filter not initialized")
}
// Create TCP DNS server lazily here since the device may not exist at construction time.
if s.tcpDNS == nil {
if dev := s.wgInterface.GetDevice(); dev != nil {
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
}
}
firstLayerDecoder := layers.LayerTypeIPv4
@@ -100,12 +118,16 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
if udpLayer == nil {
return true
}
udp, ok := udpLayer.(*layers.UDP)
if !ok {
return true
}
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
@@ -113,13 +135,30 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
dev := s.wgInterface.GetDevice()
if dev == nil {
return true
}
go s.dnsMux.ServeDNS(&writer, msg)
writer := &responseWriter{
remote: remoteAddrFromPacket(packet),
packet: packet,
device: dev.Device,
}
go s.dnsMux.ServeDNS(writer, msg)
return true
}
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
if s.tcpDNS != nil {
tcpHook := func(packetData []byte) bool {
s.tcpDNS.InjectPacket(packetData)
return true
}
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
s.tcpHookSet = true
}
return nil
}

View File

@@ -0,0 +1,444 @@
package dns
import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
dnsTCPReceiveWindow = 8192
dnsTCPMaxInFlight = 16
dnsTCPIdleTimeout = 30 * time.Second
dnsTCPReadTimeout = 5 * time.Second
)
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
// It is started lazily when a truncated DNS response is detected and shuts down
// after a period of inactivity to conserve resources.
type tcpDNSServer struct {
mu sync.Mutex
s *stack.Stack
ep *dnsEndpoint
mux *dns.ServeMux
tunDev tun.Device
ip netip.Addr
port uint16
mtu uint16
running bool
closed bool
timerID uint64
timer *time.Timer
}
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
return &tcpDNSServer{
mux: mux,
tunDev: tunDev,
ip: ip,
port: port,
mtu: mtu,
}
}
// InjectPacket ensures the stack is running and delivers a raw IP packet into
// the gvisor stack for TCP processing. Combining both operations under a single
// lock prevents a race where the idle timer could stop the stack between
// start and delivery.
func (t *tcpDNSServer) InjectPacket(payload []byte) {
t.mu.Lock()
defer t.mu.Unlock()
if t.closed {
return
}
if !t.running {
if err := t.startLocked(); err != nil {
log.Errorf("failed to start TCP DNS stack: %v", err)
return
}
t.running = true
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
}
t.resetTimerLocked()
ep := t.ep
if ep == nil || ep.dispatcher == nil {
return
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
// Stop tears down the gvisor stack and releases resources permanently.
// After Stop, InjectPacket becomes a no-op.
func (t *tcpDNSServer) Stop() {
t.mu.Lock()
defer t.mu.Unlock()
t.stopLocked()
t.closed = true
}
func (t *tcpDNSServer) startLocked() error {
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
HandleLocal: false,
})
nicID := tcpip.NICID(1)
ep := &dnsEndpoint{
tunDev: t.tunDev,
}
ep.mtu.Store(uint32(t.mtu))
if err := s.CreateNIC(nicID, ep); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("create NIC: %v", err)
}
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
PrefixLen: 32,
},
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("add protocol address: %s", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("set promiscuous mode: %s", err)
}
if err := s.SetSpoofing(nicID, true); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("set spoofing: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
)
if err != nil {
s.Close()
s.Wait()
return fmt.Errorf("create default subnet: %w", err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: defaultSubnet, NIC: nicID},
})
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
t.handleTCPDNS(r)
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
t.s = s
t.ep = ep
return nil
}
func (t *tcpDNSServer) stopLocked() {
if !t.running {
return
}
if t.timer != nil {
t.timer.Stop()
t.timer = nil
}
if t.s != nil {
t.s.Close()
t.s.Wait()
t.s = nil
}
t.ep = nil
t.running = false
log.Debugf("TCP DNS stack stopped")
}
func (t *tcpDNSServer) resetTimerLocked() {
if t.timer != nil {
t.timer.Stop()
}
t.timerID++
id := t.timerID
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
t.mu.Lock()
defer t.mu.Unlock()
// Only stop if this timer is still the active one.
// A racing InjectPacket may have replaced it.
if t.timerID != id {
return
}
t.stopLocked()
})
}
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
id := r.ID()
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
r.Complete(true)
return
}
r.Complete(false)
conn := gonet.NewTCPConn(&wq, ep)
defer func() {
if err := conn.Close(); err != nil {
log.Tracef("TCP DNS: close conn: %v", err)
}
}()
// Reset idle timer on activity
t.mu.Lock()
t.resetTimerLocked()
t.mu.Unlock()
localAddr := &net.TCPAddr{
IP: id.LocalAddress.AsSlice(),
Port: int(id.LocalPort),
}
remoteAddr := &net.TCPAddr{
IP: id.RemoteAddress.AsSlice(),
Port: int(id.RemotePort),
}
for {
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
break
}
msg, err := readTCPDNSMessage(conn)
if err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
}
break
}
writer := &tcpResponseWriter{
conn: conn,
localAddr: localAddr,
remoteAddr: remoteAddr,
}
t.mux.ServeDNS(writer, msg)
}
}
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
type dnsEndpoint struct {
dispatcher stack.NetworkDispatcher
tunDev tun.Device
mtu atomic.Uint32
}
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
func (e *dnsEndpoint) Wait() { /* no async work */ }
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
const tunPacketOffset = 40
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
raw := data.AsSlice()
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
buf = append(buf, raw...)
data.Release()
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
continue
}
written++
}
return written, nil
}
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
type tcpResponseWriter struct {
conn *gonet.TCPConn
localAddr net.Addr
remoteAddr net.Addr
}
func (w *tcpResponseWriter) LocalAddr() net.Addr {
return w.localAddr
}
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
return w.remoteAddr
}
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
data, err := msg.Pack()
if err != nil {
return fmt.Errorf("pack: %w", err)
}
// DNS TCP: 2-byte length prefix + message
buf := make([]byte, 2+len(data))
buf[0] = byte(len(data) >> 8)
buf[1] = byte(len(data))
copy(buf[2:], data)
if _, err = w.conn.Write(buf); err != nil {
return err
}
return nil
}
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
buf := make([]byte, 2+len(data))
buf[0] = byte(len(data) >> 8)
buf[1] = byte(len(data))
copy(buf[2:], data)
if _, err := w.conn.Write(buf); err != nil {
return 0, err
}
return len(data), nil
}
func (w *tcpResponseWriter) Close() error {
return w.conn.Close()
}
func (w *tcpResponseWriter) TsigStatus() error { return nil }
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
// DNS over TCP uses a 2-byte length prefix
lenBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return nil, fmt.Errorf("read length: %w", err)
}
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
if msgLen == 0 || msgLen > 65535 {
return nil, fmt.Errorf("invalid message length: %d", msgLen)
}
msgBuf := make([]byte, msgLen)
if _, err := io.ReadFull(conn, msgBuf); err != nil {
return nil, fmt.Errorf("read message: %w", err)
}
msg := new(dns.Msg)
if err := msg.Unpack(msgBuf); err != nil {
return nil, fmt.Errorf("unpack: %w", err)
}
return msg, nil
}
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
// Supports both IPv4 and IPv6.
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
if len(pkt) == 0 {
return netip.AddrPort{}
}
srcIP, transportOffset := srcIPFromPacket(pkt)
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
return netip.AddrPort{}
}
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
}
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
switch header.IPVersion(pkt) {
case 4:
return srcIPv4(pkt)
case 6:
return srcIPv6(pkt)
default:
return netip.Addr{}, 0
}
}
func srcIPv4(pkt []byte) (netip.Addr, int) {
if len(pkt) < header.IPv4MinimumSize {
return netip.Addr{}, 0
}
hdr := header.IPv4(pkt)
src := hdr.SourceAddress()
ip, ok := netip.AddrFromSlice(src.AsSlice())
if !ok {
return netip.Addr{}, 0
}
return ip, int(hdr.HeaderLength())
}
func srcIPv6(pkt []byte) (netip.Addr, int) {
if len(pkt) < header.IPv6MinimumSize {
return netip.Addr{}, 0
}
hdr := header.IPv6(pkt)
src := hdr.SourceAddress()
ip, ok := netip.AddrFromSlice(src.AsSlice())
if !ok {
return netip.Addr{}, 0
}
return ip, header.IPv6MinimumSize
}

View File

@@ -41,10 +41,61 @@ const (
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
// payload from the tunnel MTU.
ipUDPHeaderSize = 60 + 8
)
const testRecord = "com."
const (
protoUDP = "udp"
protoTCP = "tcp"
)
type dnsProtocolKey struct{}
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
return context.WithValue(ctx, dnsProtocolKey{}, network)
}
// dnsProtocolFromContext retrieves the inbound DNS protocol from context.
func dnsProtocolFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
return v
}
return ""
}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
r := &upstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
func setUpstreamProtocol(ctx context.Context, protocol string) {
if ctx == nil {
return
}
if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil {
r.protocol = protocol
}
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
@@ -138,7 +189,16 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
ok, failures := u.tryUpstreamServers(w, r, logger)
// Propagate inbound protocol so upstream exchange can use TCP directly
// when the request came in over TCP.
ctx := u.ctx
if addr := w.RemoteAddr(); addr != nil {
network := addr.Network()
ctx = contextWithDNSProtocol(ctx, network)
resutil.SetMeta(w, "protocol", network)
}
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
if len(failures) > 0 {
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
}
@@ -153,7 +213,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
}
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
@@ -168,7 +228,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
@@ -178,15 +238,17 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
}
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
@@ -203,7 +265,7 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
}
@@ -220,10 +282,13 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
u.successCount.Add(1)
resutil.SetMeta(w, "upstream", upstream.String())
if upstreamProto != nil && upstreamProto.protocol != "" {
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
}
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
@@ -428,13 +493,42 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC
return err
}
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
func clientUDPMaxSize(r *dns.Msg) int {
if opt := r.IsEdns0(); opt != nil {
return int(opt.UDPSize())
}
return dns.MinMsgSize
}
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// MTU - ip + udp headers
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
client.UDPSize = uint16(currentMTU - (60 + 8))
// If the request came in over TCP, go straight to TCP upstream.
if dnsProtocolFromContext(ctx) == protoTCP {
tcpClient := *client
tcpClient.Net = protoTCP
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
setUpstreamProtocol(ctx, protoTCP)
return rm, t, nil
}
clientMaxSize := clientUDPMaxSize(r)
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
// response larger than our read buffer.
// Note: the query could be sent out on an interface that is not ours,
// but higher MTU settings could break truncation handling.
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
client.UDPSize = maxUDPPayload
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
opt.SetUDPSize(maxUDPPayload)
}
var (
rm *dns.Msg
@@ -453,25 +547,32 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
}
if rm == nil || !rm.MsgHdr.Truncated {
setUpstreamProtocol(ctx, protoUDP)
return rm, t, nil
}
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// TODO: if the upstream's truncated UDP response already contains more
// data than the client's buffer, we could truncate locally and skip
// the TCP retry.
client.Net = "tcp"
tcpClient := *client
tcpClient.Net = protoTCP
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
rm, t, err = tcpClient.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
}
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
setUpstreamProtocol(ctx, protoTCP)
if rm.Len() > clientMaxSize {
rm.Truncate(clientMaxSize)
}
return rm, t, nil
}
@@ -479,18 +580,46 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
// If request came in over TCP, go straight to TCP upstream
if dnsProtocolFromContext(ctx) == protoTCP {
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
if err != nil {
return nil, err
}
setUpstreamProtocol(ctx, protoTCP)
return rm, nil
}
clientMaxSize := clientUDPMaxSize(r)
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
// response larger than what we can read over UDP.
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
opt.SetUDPSize(maxUDPPayload)
}
reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP)
if err != nil {
return nil, err
}
// If response is truncated, retry with TCP
if reply != nil && reply.MsgHdr.Truncated {
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
if err != nil {
return nil, err
}
setUpstreamProtocol(ctx, protoTCP)
if rm.Len() > clientMaxSize {
rm.Truncate(clientMaxSize)
}
return rm, nil
}
setUpstreamProtocol(ctx, protoUDP)
return reply, nil
}
@@ -511,7 +640,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
}
}
dnsConn := &dns.Conn{Conn: conn}
dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)}
if err := dnsConn.WriteMsg(r); err != nil {
return nil, fmt.Errorf("write %s message: %w", network, err)

View File

@@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
upstreamExchangeClient := &dns.Client{
Timeout: ClientTimeout,
}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
}
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
@@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
Timeout: timeout,
}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
}
func (u *upstreamResolver) isLocalResolver(upstream string) bool {

View File

@@ -475,3 +475,298 @@ func TestFormatFailures(t *testing.T) {
})
}
}
func TestDNSProtocolContext(t *testing.T) {
t.Run("roundtrip udp", func(t *testing.T) {
ctx := contextWithDNSProtocol(context.Background(), protoUDP)
assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx))
})
t.Run("roundtrip tcp", func(t *testing.T) {
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx))
})
t.Run("missing returns empty", func(t *testing.T) {
assert.Equal(t, "", dnsProtocolFromContext(context.Background()))
})
}
func TestExchangeWithFallback_TCPContext(t *testing.T) {
// Start a local DNS server that responds on TCP only
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpServer := &dns.Server{
Addr: "127.0.0.1:0",
Net: "tcp",
Handler: tcpHandler,
}
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tcpServer.Listener = tcpLn
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = tcpServer.Shutdown()
}()
upstream := tcpLn.Addr().String()
// With TCP context, should connect directly via TCP without trying UDP
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
}
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
// UDP handler returns a truncated response to trigger TCP retry.
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Truncated = true
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
// TCP handler returns the full answer.
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.3"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{
PacketConn: udpPC,
Net: "udp",
Handler: udpHandler,
}
tcpLn, err := net.Listen("tcp", addr)
require.NoError(t, err)
tcpServer := &dns.Server{
Listener: tcpLn,
Net: "tcp",
Handler: tcpHandler,
}
go func() {
if err := udpServer.ActivateAndServe(); err != nil {
t.Logf("udp server: %v", err)
}
}()
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = udpServer.Shutdown()
_ = tcpServer.Shutdown()
}()
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err, "should fall back to TCP after truncated UDP response")
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer")
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
assert.False(t, rm.Truncated, "TCP response should not be truncated")
}
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
// Start only a TCP server (no UDP). With TCP context it should succeed.
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.2"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tcpServer := &dns.Server{
Listener: tcpLn,
Net: "tcp",
Handler: tcpHandler,
}
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = tcpServer.Shutdown()
}()
upstream := tcpLn.Addr().String()
// TCP context: should skip UDP entirely and go directly to TCP
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
// Without TCP context, trying to reach a TCP-only server via UDP should fail
ctx2 := context.Background()
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
assert.Error(t, err, "should fail when no UDP server and no TCP context")
}
func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
// Verify that a client EDNS0 larger than our MTU-derived limit gets
// capped in the outgoing request so the upstream doesn't send a
// response larger than our read buffer.
var receivedUDPSize uint16
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
if opt := r.IsEdns0(); opt != nil {
receivedUDPSize = opt.UDPSize()
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
go func() { _ = udpServer.ActivateAndServe() }()
t.Cleanup(func() { _ = udpServer.Shutdown() })
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
r.SetEdns0(4096, false)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err)
require.NotNil(t, rm)
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
assert.Equal(t, expectedMax, receivedUDPSize,
"upstream should see capped EDNS0, not the client's 4096")
}
func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
// When the client advertises a large EDNS0 (4096) and the upstream
// truncates, the TCP response should NOT be truncated since the full
// answer fits within the client's original buffer.
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Truncated = true
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
// Add enough records to exceed MTU but fit within 4096
for i := range 20 {
m.Answer = append(m.Answer, &dns.TXT{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)},
})
}
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
tcpLn, err := net.Listen("tcp", addr)
require.NoError(t, err)
tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler}
go func() { _ = udpServer.ActivateAndServe() }()
go func() { _ = tcpServer.ActivateAndServe() }()
t.Cleanup(func() {
_ = udpServer.Shutdown()
_ = tcpServer.Shutdown()
})
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
// Client with large buffer: should get all records without truncation
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
r.SetEdns0(4096, false)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err)
require.NotNil(t, rm)
assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records")
assert.False(t, rm.Truncated, "response should not be truncated for large buffer client")
// Client with small buffer: should get truncated response
r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
r2.SetEdns0(512, false)
rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr)
require.NoError(t, err)
require.NotNil(t, rm2)
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
}

View File

@@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime))
}
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
@@ -263,20 +263,28 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
logger := log.WithFields(log.Fields{
fields := log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
logger := log.WithFields(log.Fields{
fields := log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
f.handleDNSQuery(logger, w, query, startTime)
}

View File

@@ -46,6 +46,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
@@ -116,6 +117,7 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
ServerRDPAllowed bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -196,6 +198,7 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
rdpServer rdpServer
statusRecorder *peer.Status
@@ -210,9 +213,10 @@ type Engine struct {
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
relayManager *relayClient.Manager
stateManager *statemanager.Manager
portForwardManager *portforward.Manager
srWatcher *guard.SRWatcher
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
@@ -259,26 +263,27 @@ func NewEngine(
mobileDep MobileDependency,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
portForwardManager: portforward.NewManager(),
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -500,7 +505,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
for _, routes := range e.routeManager.GetClientRoutes() {
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
for _, r := range routes {
if r.Network.Contains(ip) {
return true
@@ -521,6 +526,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return err
}
// Inject firewall into DNS server now that it's available.
// The DNS server is created before the firewall because the route manager
// depends on the DNS server, and the firewall depends on the wg interface.
e.dnsServer.SetFirewall(e.firewall)
e.udpMux, err = e.wgInterface.Up()
if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
@@ -532,6 +542,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack()
// Start after interface is up since port may have been resolved from 0 or changed if occupied
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
}()
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
@@ -1021,6 +1038,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
}
if err := e.updateRDP(); err != nil {
log.Warnf("failed handling RDP server setup: %v", err)
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
@@ -1308,6 +1329,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
// Reuse SSH ACL for RDP authorization
e.updateRDPServerAuth(networkMap.GetSshAuth())
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
@@ -1535,12 +1559,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
MetricsRecorder: e.clientMetrics,
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
PortForwardManager: e.portForwardManager,
MetricsRecorder: e.clientMetrics,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
@@ -1697,6 +1722,12 @@ func (e *Engine) close() {
if e.rpManager != nil {
_ = e.rpManager.Close()
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
}
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
@@ -1800,7 +1831,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil
default:
@@ -1837,6 +1868,11 @@ func (e *Engine) GetExposeManager() *expose.Manager {
return e.exposeManager
}
// IsBlockInbound returns whether inbound connections are blocked.
func (e *Engine) IsBlockInbound() bool {
return e.config.BlockInbound
}
// GetClientMetrics returns the client metrics
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics

View File

@@ -0,0 +1,191 @@
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
type rdpServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
GetPendingStore() *rdpserver.PendingStore
UpdateRDPAuth(config *sshauth.Config)
}
func (e *Engine) setupRDPPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, rdpserver.DefaultRDPAuthPort, rdpserver.InternalRDPAuthPort); err != nil {
return fmt.Errorf("add RDP auth port redirection: %w", err)
}
log.Infof("RDP auth port redirection enabled: %s:%d -> %s:%d",
localAddr, rdpserver.DefaultRDPAuthPort, localAddr, rdpserver.InternalRDPAuthPort)
return nil
}
func (e *Engine) cleanupRDPPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, rdpserver.DefaultRDPAuthPort, rdpserver.InternalRDPAuthPort); err != nil {
return fmt.Errorf("remove RDP auth port redirection: %w", err)
}
log.Debugf("RDP auth port redirection removed: %s:%d -> %s:%d",
localAddr, rdpserver.DefaultRDPAuthPort, localAddr, rdpserver.InternalRDPAuthPort)
return nil
}
// updateRDP handles starting/stopping the RDP server based on the config flag.
func (e *Engine) updateRDP() error {
if !e.config.ServerRDPAllowed {
if e.rdpServer != nil {
log.Info("RDP passthrough disabled, stopping RDP auth server")
}
return e.stopRDPServer()
}
if e.config.BlockInbound {
log.Info("RDP server is disabled because inbound connections are blocked")
return e.stopRDPServer()
}
if e.rdpServer != nil {
log.Debug("RDP auth server is already running")
return nil
}
return e.startRDPServer()
}
func (e *Engine) startRDPServer() error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
wgAddr := e.wgInterface.Address()
cfg := &rdpserver.Config{
NetworkAddr: wgAddr.Network,
}
server := rdpserver.New(cfg)
netbirdIP := wgAddr.IP
listenAddr := netip.AddrPortFrom(netbirdIP, rdpserver.InternalRDPAuthPort)
if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start RDP auth server: %w", err)
}
e.rdpServer = server
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, rdpserver.InternalRDPAuthPort)
log.Debugf("registered RDP auth service with netstack for TCP:%d", rdpserver.InternalRDPAuthPort)
}
}
if err := e.setupRDPPortRedirection(); err != nil {
log.Warnf("failed to setup RDP auth port redirection: %v", err)
}
// Register the credential provider DLL dynamically (Windows only)
if err := rdpserver.RegisterCredentialProvider(); err != nil {
log.Warnf("failed to register RDP credential provider (passwordless RDP will not work): %v", err)
}
log.Info("RDP passthrough enabled")
return nil
}
func (e *Engine) stopRDPServer() error {
if e.rdpServer == nil {
return nil
}
if err := e.cleanupRDPPortRedirection(); err != nil {
log.Warnf("failed to cleanup RDP auth port redirection: %v", err)
}
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, rdpserver.InternalRDPAuthPort)
log.Debugf("unregistered RDP auth service from netstack for TCP:%d", rdpserver.InternalRDPAuthPort)
}
}
// Unregister the credential provider DLL (Windows only)
if err := rdpserver.UnregisterCredentialProvider(); err != nil {
log.Warnf("failed to unregister RDP credential provider: %v", err)
}
log.Info("stopping RDP auth server")
err := e.rdpServer.Stop()
e.rdpServer = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
return nil
}
// updateRDPServerAuth reuses the SSH authorization config for RDP access control.
// This means the same user/machine-user mappings that control SSH access also control RDP.
func (e *Engine) updateRDPServerAuth(sshAuth *mgmProto.SSHAuth) {
if sshAuth == nil || e.rdpServer == nil {
return
}
protoUsers := sshAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid hash length %d, expected 16 - skipping RDP server auth update", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range sshAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
authConfig := &sshauth.Config{
UserIDClaim: sshAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
}
e.rdpServer.UpdateRDPAuth(authConfig)
}

View File

@@ -828,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
@@ -1035,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
@@ -1538,13 +1538,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
return nil, err
}
publicKey, err := mgmtClient.GetServerPublicKey()
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
@@ -1566,7 +1561,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,

View File

@@ -22,6 +22,7 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -45,6 +46,7 @@ type ServiceDependencies struct {
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
PeerConnDispatcher *dispatcher.ConnectionDispatcher
PortForwardManager *portforward.Manager
MetricsRecorder MetricsRecorder
}
@@ -87,16 +89,17 @@ type ConnConfig struct {
}
type Conn struct {
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
portForwardManager *portforward.Manager
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
@@ -145,19 +148,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
portForwardManager: services.PortForwardManager,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
}
return conn, nil

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
)
@@ -61,6 +62,9 @@ type WorkerICE struct {
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
// portForwardAttempted tracks if we've already tried port forwarding this session
portForwardAttempted bool
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
@@ -214,6 +218,8 @@ func (w *WorkerICE) Close() {
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
w.portForwardAttempted = false
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
@@ -370,6 +376,93 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
if candidate.Type() == ice.CandidateTypeServerReflexive {
w.injectPortForwardedCandidate(candidate)
}
}
// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
pfManager := w.conn.portForwardManager
if pfManager == nil {
return
}
mapping := pfManager.GetMapping()
if mapping == nil {
return
}
w.muxAgent.Lock()
if w.portForwardAttempted {
w.muxAgent.Unlock()
return
}
w.portForwardAttempted = true
w.muxAgent.Unlock()
forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
if err != nil {
w.log.Warnf("create forwarded candidate: %v", err)
return
}
w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())
go func() {
if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
w.log.Errorf("signal port-forwarded candidate: %v", err)
}
}()
}
// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
// It uses the NAT gateway's external IP with the forwarded port.
func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
var externalIP string
if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
externalIP = mapping.ExternalIP.String()
} else {
// Fallback to STUN-discovered address if NAT didn't provide external IP
externalIP = srflxCandidate.Address()
}
// Per RFC 8445, the related address for srflx is the base (host candidate address).
// If the original srflx has unspecified related address, use its own address as base.
relAddr := srflxCandidate.RelatedAddress().Address
if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
relAddr = srflxCandidate.Address()
}
// Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
// over regular srflx during ICE connectivity checks.
priority := srflxCandidate.Priority() + 1000
candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: srflxCandidate.NetworkType().String(),
Address: externalIP,
Port: int(mapping.ExternalPort),
Component: srflxCandidate.Component(),
Priority: priority,
RelAddr: relAddr,
RelPort: int(mapping.InternalPort),
})
if err != nil {
return nil, fmt.Errorf("create candidate: %w", err)
}
for _, e := range srflxCandidate.Extensions() {
if e.Key == ice.ExtensionKeyCandidateID {
e.Value = srflxCandidate.ID()
}
if err := candidate.AddExtension(e); err != nil {
return nil, fmt.Errorf("add extension: %w", err)
}
}
return candidate, nil
}
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
@@ -411,10 +504,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
if !lok || !rok {
continue
}
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
sessionID,
local.NetworkType(), local.Type(), local.Address(),
remote.NetworkType(), remote.Type(), remote.Address(),
local.NetworkType(), local.Type(), local.Address(), local.Port(),
remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
stat.CurrentRoundTripTime*1000)
}
}

View File

@@ -0,0 +1,26 @@
package portforward
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
)
func isDisabledByEnv() bool {
val := os.Getenv(envDisableNATMapper)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
return false
}
return disabled
}

View File

@@ -0,0 +1,280 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"net"
"regexp"
"sync"
"time"
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
)
const (
defaultMappingTTL = 2 * time.Hour
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
)
// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
// allowing for whitespace/newlines between tags from different router firmware.
var upnpErrPermanentLeaseOnly = regexp.MustCompile(`<errorCode>\s*725\s*</errorCode>`)
// Mapping represents an active NAT port mapping.
type Mapping struct {
Protocol string
InternalPort uint16
ExternalPort uint16
ExternalIP net.IP
NATType string
// TTL is the lease duration. Zero means a permanent lease that never expires.
TTL time.Duration
}
// TODO: persist mapping state for crash recovery cleanup of permanent leases.
// Currently not done because State.Cleanup requires NAT gateway re-discovery,
// which blocks startup for ~10s when no gateway is present (affects all clients).
type Manager struct {
cancel context.CancelFunc
mapping *Mapping
mappingLock sync.Mutex
wgPort uint16
done chan struct{}
stopCtx chan context.Context
// protect exported functions
mu sync.Mutex
}
// NewManager creates a new port forwarding manager.
func NewManager() *Manager {
return &Manager{
stopCtx: make(chan context.Context, 1),
}
}
func (m *Manager) Start(ctx context.Context, wgPort uint16) {
m.mu.Lock()
if m.cancel != nil {
m.mu.Unlock()
return
}
if isDisabledByEnv() {
log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
m.mu.Unlock()
return
}
if wgPort == 0 {
log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
m.mu.Unlock()
return
}
m.wgPort = wgPort
m.done = make(chan struct{})
defer close(m.done)
ctx, m.cancel = context.WithCancel(ctx)
m.mu.Unlock()
gateway, mapping, err := m.setup(ctx)
if err != nil {
log.Infof("port forwarding setup: %v", err)
return
}
m.mappingLock.Lock()
m.mapping = mapping
m.mappingLock.Unlock()
m.renewLoop(ctx, gateway, mapping.TTL)
select {
case cleanupCtx := <-m.stopCtx:
// block the Start while cleaned up gracefully
m.cleanup(cleanupCtx, gateway)
default:
// return Start immediately and cleanup in background
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
defer cleanupCancel()
m.cleanup(cleanupCtx, gateway)
}()
}
}
// GetMapping returns the current mapping if ready, nil otherwise
func (m *Manager) GetMapping() *Mapping {
m.mappingLock.Lock()
defer m.mappingLock.Unlock()
if m.mapping == nil {
return nil
}
mapping := *m.mapping
return &mapping
}
// GracefullyStop cancels the manager and attempts to delete the port mapping.
// After GracefullyStop returns, the manager cannot be restarted.
func (m *Manager) GracefullyStop(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel == nil {
return nil
}
// Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
m.startTearDown(ctx)
m.cancel()
m.cancel = nil
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
return nil
}
}
func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
defer discoverCancel()
gateway, err := nat.DiscoverGateway(discoverCtx)
if err != nil {
return nil, nil, fmt.Errorf("discover gateway: %w", err)
}
log.Infof("discovered NAT gateway: %s", gateway.Type())
mapping, err := m.createMapping(ctx, gateway)
if err != nil {
return nil, nil, fmt.Errorf("create port mapping: %w", err)
}
return gateway, mapping, nil
}
func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
ttl := defaultMappingTTL
externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
if err != nil {
if !isPermanentLeaseRequired(err) {
return nil, err
}
log.Infof("gateway only supports permanent leases, retrying with indefinite duration")
ttl = 0
externalPort, err = gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
if err != nil {
return nil, err
}
}
externalIP, err := gateway.GetExternalAddress()
if err != nil {
log.Debugf("failed to get external address: %v", err)
// todo return with err?
}
mapping := &Mapping{
Protocol: "udp",
InternalPort: m.wgPort,
ExternalPort: uint16(externalPort),
ExternalIP: externalIP,
NATType: gateway.Type(),
TTL: ttl,
}
log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
m.wgPort, externalPort, gateway.Type(), externalIP)
return mapping, nil
}
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
if ttl == 0 {
// Permanent mappings don't expire, just wait for cancellation.
<-ctx.Done()
return
}
ticker := time.NewTicker(ttl / 2)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := m.renewMapping(ctx, gateway); err != nil {
log.Warnf("failed to renew port mapping: %v", err)
continue
}
}
}
}
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, m.mapping.TTL)
if err != nil {
return fmt.Errorf("add port mapping: %w", err)
}
if uint16(externalPort) != m.mapping.ExternalPort {
log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
m.mappingLock.Lock()
m.mapping.ExternalPort = uint16(externalPort)
m.mappingLock.Unlock()
}
log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
return nil
}
func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
m.mappingLock.Lock()
mapping := m.mapping
m.mapping = nil
m.mappingLock.Unlock()
if mapping == nil {
return
}
if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
log.Warnf("delete port mapping on stop: %v", err)
return
}
log.Infof("deleted port mapping for port %d", mapping.InternalPort)
}
func (m *Manager) startTearDown(ctx context.Context) {
select {
case m.stopCtx <- ctx:
default:
}
}
// isPermanentLeaseRequired checks if a UPnP error indicates the gateway only supports permanent leases (error 725).
func isPermanentLeaseRequired(err error) bool {
return err != nil && upnpErrPermanentLeaseOnly.MatchString(err.Error())
}

View File

@@ -0,0 +1,39 @@
package portforward
import (
"context"
"net"
"time"
)
// Mapping represents an active NAT port mapping.
type Mapping struct {
Protocol string
InternalPort uint16
ExternalPort uint16
ExternalIP net.IP
NATType string
// TTL is the lease duration. Zero means a permanent lease that never expires.
TTL time.Duration
}
// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported.
type Manager struct{}
// NewManager returns a stub manager for js/wasm builds.
func NewManager() *Manager {
return &Manager{}
}
// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments.
func (m *Manager) Start(context.Context, uint16) {
// no NAT traversal in wasm
}
// GracefullyStop is a no-op on js/wasm.
func (m *Manager) GracefullyStop(context.Context) error { return nil }
// GetMapping always returns nil on js/wasm.
func (m *Manager) GetMapping() *Mapping {
return nil
}

View File

@@ -0,0 +1,201 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockNAT struct {
natType string
deviceAddr net.IP
externalAddr net.IP
internalAddr net.IP
mappings map[int]int
addMappingErr error
deleteMappingErr error
onlyPermanentLeases bool
lastTimeout time.Duration
}
func newMockNAT() *mockNAT {
return &mockNAT{
natType: "Mock-NAT",
deviceAddr: net.ParseIP("192.168.1.1"),
externalAddr: net.ParseIP("203.0.113.50"),
internalAddr: net.ParseIP("192.168.1.100"),
mappings: make(map[int]int),
}
}
func (m *mockNAT) Type() string {
return m.natType
}
func (m *mockNAT) GetDeviceAddress() (net.IP, error) {
return m.deviceAddr, nil
}
func (m *mockNAT) GetExternalAddress() (net.IP, error) {
return m.externalAddr, nil
}
func (m *mockNAT) GetInternalAddress() (net.IP, error) {
return m.internalAddr, nil
}
func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) {
if m.addMappingErr != nil {
return 0, m.addMappingErr
}
if m.onlyPermanentLeases && timeout != 0 {
return 0, fmt.Errorf("SOAP fault. Code: | Explanation: | Detail: <UPnPError xmlns=\"urn:schemas-upnp-org:control-1-0\"><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>")
}
externalPort := internalPort
m.mappings[internalPort] = externalPort
m.lastTimeout = timeout
return externalPort, nil
}
func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
if m.deleteMappingErr != nil {
return m.deleteMappingErr
}
delete(m.mappings, internalPort)
return nil
}
func TestManager_CreateMapping(t *testing.T) {
m := NewManager()
m.wgPort = 51820
gateway := newMockNAT()
mapping, err := m.createMapping(context.Background(), gateway)
require.NoError(t, err)
require.NotNil(t, mapping)
assert.Equal(t, "udp", mapping.Protocol)
assert.Equal(t, uint16(51820), mapping.InternalPort)
assert.Equal(t, uint16(51820), mapping.ExternalPort)
assert.Equal(t, "Mock-NAT", mapping.NATType)
assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4())
assert.Equal(t, defaultMappingTTL, mapping.TTL)
}
func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) {
m := NewManager()
assert.Nil(t, m.GetMapping())
}
func TestManager_GetMapping_ReturnsCopy(t *testing.T) {
m := NewManager()
m.mapping = &Mapping{
Protocol: "udp",
InternalPort: 51820,
ExternalPort: 51820,
}
mapping := m.GetMapping()
require.NotNil(t, mapping)
assert.Equal(t, uint16(51820), mapping.InternalPort)
// Mutating the returned copy should not affect the manager's mapping.
mapping.ExternalPort = 9999
assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort)
}
func TestManager_Cleanup_DeletesMapping(t *testing.T) {
m := NewManager()
m.mapping = &Mapping{
Protocol: "udp",
InternalPort: 51820,
ExternalPort: 51820,
}
gateway := newMockNAT()
// Seed the mock so we can verify deletion.
gateway.mappings[51820] = 51820
m.cleanup(context.Background(), gateway)
_, exists := gateway.mappings[51820]
assert.False(t, exists, "mapping should be deleted from gateway")
assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared")
}
func TestManager_Cleanup_NilMapping(t *testing.T) {
m := NewManager()
gateway := newMockNAT()
// Should not panic or call gateway.
m.cleanup(context.Background(), gateway)
}
func TestManager_CreateMapping_PermanentLeaseFallback(t *testing.T) {
m := NewManager()
m.wgPort = 51820
gateway := newMockNAT()
gateway.onlyPermanentLeases = true
mapping, err := m.createMapping(context.Background(), gateway)
require.NoError(t, err)
require.NotNil(t, mapping)
assert.Equal(t, uint16(51820), mapping.InternalPort)
assert.Equal(t, time.Duration(0), mapping.TTL, "should return zero TTL for permanent lease")
assert.Equal(t, time.Duration(0), gateway.lastTimeout, "should have retried with zero duration")
}
func TestIsPermanentLeaseRequired(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "UPnP error 725",
err: fmt.Errorf("SOAP fault. Code: | Detail: <UPnPError><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>"),
expected: true,
},
{
name: "wrapped error with 725",
err: fmt.Errorf("add port mapping: %w", fmt.Errorf("Detail: <errorCode>725</errorCode>")),
expected: true,
},
{
name: "error 725 with newlines in XML",
err: fmt.Errorf("<errorCode>\n 725\n</errorCode>"),
expected: true,
},
{
name: "bare 725 without XML tag",
err: fmt.Errorf("error code 725"),
expected: false,
},
{
name: "unrelated error",
err: fmt.Errorf("connection refused"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isPermanentLeaseRequired(tt.err))
})
}
}

View File

@@ -41,7 +41,7 @@ const (
// mgmProber is the subset of management client needed for URL migration probes.
type mgmProber interface {
GetServerPublicKey() (*wgtypes.Key, error)
HealthCheck() error
Close() error
}
@@ -64,6 +64,7 @@ type ConfigInput struct {
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
ServerRDPAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -114,6 +115,7 @@ type Config struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
ServerRDPAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -415,6 +417,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerRDPAllowed != nil {
if config.ServerRDPAllowed == nil || *input.ServerRDPAllowed != *config.ServerRDPAllowed {
if *input.ServerRDPAllowed {
log.Infof("enabling RDP passthrough")
} else {
log.Infof("disabling RDP passthrough")
}
config.ServerRDPAllowed = input.ServerRDPAllowed
updated = true
}
} else if config.ServerRDPAllowed == nil {
config.ServerRDPAllowed = util.False()
updated = true
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")
@@ -777,8 +794,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
}()
// gRPC check
_, err = client.GetServerPublicKey()
if err != nil {
if err = client.HealthCheck(); err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return nil, err
}

View File

@@ -17,12 +17,10 @@ import (
"github.com/netbirdio/netbird/util"
)
type mockMgmProber struct {
key wgtypes.Key
}
type mockMgmProber struct{}
func (m *mockMgmProber) GetServerPublicKey() (*wgtypes.Key, error) {
return &m.key, nil
func (m *mockMgmProber) HealthCheck() error {
return nil
}
func (m *mockMgmProber) Close() error { return nil }
@@ -247,11 +245,7 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
func TestUpdateOldManagementURL(t *testing.T) {
origProber := newMgmProber
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
key, err := wgtypes.GenerateKey()
if err != nil {
return nil, err
}
return &mockMgmProber{key: key.PublicKey()}, nil
return &mockMgmProber{}, nil
}
t.Cleanup(func() { newMgmProber = origProber })

View File

@@ -52,6 +52,7 @@ type Manager interface {
TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
GetSelectedClientRoutes() route.HAMap
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
@@ -465,6 +466,16 @@ func (m *DefaultManager) GetClientRoutes() route.HAMap {
return maps.Clone(m.clientRoutes)
}
// GetSelectedClientRoutes returns only the currently selected/active client routes,
// filtering out deselected exit nodes. Use this instead of GetClientRoutes when checking
// if traffic should be routed through the tunnel.
func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
m.mux.Lock()
defer m.mux.Unlock()
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
m.mux.Lock()

View File

@@ -18,6 +18,7 @@ type MockManager struct {
TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
GetSelectedClientRoutesFunc func() route.HAMap
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
StopFunc func(manager *statemanager.Manager)
}
@@ -61,7 +62,7 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
return nil
}
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
// GetClientRoutes mock implementation of GetClientRoutes from the Manager interface
func (m *MockManager) GetClientRoutes() route.HAMap {
if m.GetClientRoutesFunc != nil {
return m.GetClientRoutesFunc()
@@ -69,6 +70,14 @@ func (m *MockManager) GetClientRoutes() route.HAMap {
return nil
}
// GetSelectedClientRoutes mock implementation of GetSelectedClientRoutes from the Manager interface
func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
if m.GetSelectedClientRoutesFunc != nil {
return m.GetSelectedClientRoutesFunc()
}
return nil
}
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
if m.GetClientRoutesWithNetIDFunc != nil {

View File

@@ -53,7 +53,6 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
n.currentPrefixes = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()

View File

@@ -161,7 +161,11 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
hostDNS := []netip.AddrPort{
netip.MustParseAddrPort("9.9.9.9:53"),
netip.MustParseAddrPort("149.112.112.112:53"),
}
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
}
// Stop the internal client and free the resources

View File

@@ -472,6 +472,7 @@ type LoginRequest struct {
EnableSSHRemotePortForwarding *bool `protobuf:"varint,37,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth *bool `protobuf:"varint,38,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
SshJWTCacheTTL *int32 `protobuf:"varint,39,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
ServerRDPAllowed *bool `protobuf:"varint,40,opt,name=serverRDPAllowed,proto3,oneof" json:"serverRDPAllowed,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -780,6 +781,13 @@ func (x *LoginRequest) GetSshJWTCacheTTL() int32 {
return 0
}
func (x *LoginRequest) GetServerRDPAllowed() bool {
if x != nil && x.ServerRDPAllowed != nil {
return *x.ServerRDPAllowed
}
return false
}
type LoginResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
@@ -1312,6 +1320,7 @@ type GetConfigResponse struct {
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"`
ServerRDPAllowed bool `protobuf:"varint,27,opt,name=serverRDPAllowed,proto3" json:"serverRDPAllowed,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -1528,6 +1537,13 @@ func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 {
return 0
}
func (x *GetConfigResponse) GetServerRDPAllowed() bool {
if x != nil {
return x.ServerRDPAllowed
}
return false
}
// PeerState contains the latest state of a peer
type PeerState struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -4139,6 +4155,7 @@ type SetConfigRequest struct {
EnableSSHRemotePortForwarding *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
SshJWTCacheTTL *int32 `protobuf:"varint,34,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
ServerRDPAllowed *bool `protobuf:"varint,35,opt,name=serverRDPAllowed,proto3,oneof" json:"serverRDPAllowed,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4411,6 +4428,13 @@ func (x *SetConfigRequest) GetSshJWTCacheTTL() int32 {
return 0
}
func (x *SetConfigRequest) GetServerRDPAllowed() bool {
if x != nil && x.ServerRDPAllowed != nil {
return *x.ServerRDPAllowed
}
return false
}
type SetConfigResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields

View File

@@ -209,6 +209,8 @@ message LoginRequest {
optional bool enableSSHRemotePortForwarding = 37;
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
optional bool serverRDPAllowed = 40;
}
message LoginResponse {
@@ -316,6 +318,8 @@ message GetConfigResponse {
bool disableSSHAuth = 25;
int32 sshJWTCacheTTL = 26;
bool serverRDPAllowed = 27;
}
// PeerState contains the latest state of a peer
@@ -677,6 +681,8 @@ message SetConfigRequest {
optional bool enableSSHRemotePortForwarding = 32;
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
optional bool serverRDPAllowed = 35;
}
message SetConfigResponse{}

View File

@@ -0,0 +1,88 @@
package client
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"time"
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
)
const (
// DefaultTimeout is the default timeout for sideband auth requests.
DefaultTimeout = 30 * time.Second
// maxResponseSize is the maximum size of an auth response in bytes.
maxResponseSize = 64 * 1024
)
// Client connects to a target peer's RDP sideband auth server to request access.
type Client struct {
Timeout time.Duration
}
// New creates a new sideband RDP auth client.
func New() *Client {
return &Client{
Timeout: DefaultTimeout,
}
}
// RequestAuth sends an authorization request to the target peer's sideband server
// and returns the response. The addr should be in "host:port" format.
func (c *Client) RequestAuth(ctx context.Context, addr string, req *rdpserver.AuthRequest) (*rdpserver.AuthResponse, error) {
timeout := c.Timeout
if timeout <= 0 {
timeout = DefaultTimeout
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, fmt.Errorf("connect to RDP auth server at %s: %w", addr, err)
}
defer func() { _ = conn.Close() }()
deadline, ok := ctx.Deadline()
if ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("set connection deadline: %w", err)
}
}
// Send request
reqData, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("marshal auth request: %w", err)
}
if _, err := conn.Write(reqData); err != nil {
return nil, fmt.Errorf("send auth request: %w", err)
}
// Signal we're done writing so the server can read the full request
if tcpConn, ok := conn.(*net.TCPConn); ok {
if err := tcpConn.CloseWrite(); err != nil {
return nil, fmt.Errorf("close write: %w", err)
}
}
// Read response
respData, err := io.ReadAll(io.LimitReader(conn, maxResponseSize))
if err != nil {
return nil, fmt.Errorf("read auth response: %w", err)
}
var resp rdpserver.AuthResponse
if err := json.Unmarshal(respData, &resp); err != nil {
return nil, fmt.Errorf("unmarshal auth response: %w", err)
}
return &resp, nil
}

View File

@@ -0,0 +1,31 @@
[package]
name = "netbird-credprov"
version = "0.1.0"
edition = "2021"
description = "NetBird RDP Credential Provider for Windows"
license = "BSD-3-Clause"
[lib]
crate-type = ["cdylib"]
[dependencies]
windows = { version = "0.58", features = [
"implement",
"Win32_Foundation",
"Win32_System_Com",
"Win32_UI_Shell",
"Win32_Security",
"Win32_Security_Authentication_Identity",
"Win32_Security_Credentials",
"Win32_System_RemoteDesktop",
"Win32_System_Threading",
] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
uuid = { version = "1", features = ["v4"] }
log = "0.4"
[profile.release]
opt-level = "s"
lto = true
strip = true

View File

@@ -0,0 +1,210 @@
//! ICredentialProviderCredential implementation.
//!
//! Represents a single "NetBird Login" credential tile on the Windows login screen.
//! When selected, it queries the local NetBird agent for pending RDP sessions and
//! performs S4U logon to authenticate the user without a password.
use crate::named_pipe_client::{NamedPipeClient, PipeResponse};
use crate::s4u;
use std::sync::Mutex;
use windows::core::*;
use windows::Win32::Foundation::*;
use windows::Win32::Security::Credentials::*;
use windows::Win32::UI::Shell::*;
/// NetBird credential tile that appears on the Windows login screen.
#[implement(ICredentialProviderCredential)]
pub struct NetBirdCredential {
/// The pending session information from the NetBird agent.
session: Mutex<Option<PipeResponse>>,
/// The remote IP address of the connecting peer.
remote_ip: Mutex<String>,
}
impl NetBirdCredential {
pub fn new(remote_ip: String, session: PipeResponse) -> Self {
Self {
session: Mutex::new(Some(session)),
remote_ip: Mutex::new(remote_ip),
}
}
}
impl ICredentialProviderCredential_Impl for NetBirdCredential_Impl {
fn Advise(&self, _pcpce: Option<&ICredentialProviderCredentialEvents>) -> Result<()> {
Ok(())
}
fn UnAdvise(&self) -> Result<()> {
Ok(())
}
fn SetSelected(&self, _pbautologon: *mut BOOL) -> Result<()> {
// Auto-logon when this credential is selected
unsafe {
if !_pbautologon.is_null() {
*_pbautologon = TRUE;
}
}
Ok(())
}
fn SetDeselected(&self) -> Result<()> {
Ok(())
}
fn GetFieldState(
&self,
_dwfieldid: u32,
_pcpfs: *mut CREDENTIAL_PROVIDER_FIELD_STATE,
_pcpfis: *mut CREDENTIAL_PROVIDER_FIELD_INTERACTIVE_STATE,
) -> Result<()> {
// We have a single display-only field showing "NetBird Login"
unsafe {
if !_pcpfs.is_null() {
*_pcpfs = CPFS_DISPLAY_IN_SELECTED_TILE;
}
if !_pcpfis.is_null() {
*_pcpfis = CPFIS_NONE;
}
}
Ok(())
}
fn GetStringValue(&self, _dwfieldid: u32) -> Result<PWSTR> {
let session = self.session.lock().unwrap();
let text = if let Some(ref s) = *session {
format!("NetBird: Logging in as {}", s.os_user)
} else {
"NetBird Login".to_string()
};
let wide: Vec<u16> = text.encode_utf16().chain(std::iter::once(0)).collect();
let ptr = unsafe {
let mem = windows::Win32::System::Com::CoTaskMemAlloc(wide.len() * 2) as *mut u16;
if mem.is_null() {
return Err(E_OUTOFMEMORY.into());
}
std::ptr::copy_nonoverlapping(wide.as_ptr(), mem, wide.len());
PWSTR(mem)
};
Ok(ptr)
}
fn GetBitmapValue(&self, _dwfieldid: u32) -> Result<HBITMAP> {
Err(E_NOTIMPL.into())
}
fn GetCheckboxValue(&self, _dwfieldid: u32, _pbchecked: *mut BOOL, _ppszlabel: *mut PWSTR) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn GetSubmitButtonValue(&self, _dwfieldid: u32, _pdwadjacentto: *mut u32) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn GetComboBoxValueCount(&self, _dwfieldid: u32, _pcitems: *mut u32, _pdwselecteditem: *mut u32) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn GetComboBoxValueAt(&self, _dwfieldid: u32, _dwitem: u32) -> Result<PWSTR> {
Err(E_NOTIMPL.into())
}
fn SetStringValue(&self, _dwfieldid: u32, _psz: &PCWSTR) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn SetCheckboxValue(&self, _dwfieldid: u32, _bchecked: BOOL) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn SetComboBoxSelectedValue(&self, _dwfieldid: u32, _dwselecteditem: u32) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn CommandLinkClicked(&self, _dwfieldid: u32) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn GetSerialization(
&self,
_pcpgsr: *mut CREDENTIAL_PROVIDER_GET_SERIALIZATION_RESPONSE,
_pcpcs: *mut CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION,
_ppszoptionalstatustext: *mut PWSTR,
_pcpsioptionalstatusicon: *mut CREDENTIAL_PROVIDER_STATUS_ICON,
) -> Result<()> {
let session = self.session.lock().unwrap();
let session_info = match &*session {
Some(s) => s.clone(),
None => {
unsafe {
*_pcpgsr = CPGSR_NO_CREDENTIAL_NOT_FINISHED;
}
return Ok(());
}
};
// Consume the session with the agent
if let Err(e) = NamedPipeClient::consume_session(&session_info.session_id) {
log::error!("Failed to consume RDP session: {}", e);
unsafe {
*_pcpgsr = CPGSR_NO_CREDENTIAL_FINISHED;
}
return Ok(());
}
// Perform S4U logon
let username = &session_info.os_user;
let domain = if session_info.domain.is_empty() {
"."
} else {
&session_info.domain
};
match s4u::generate_s4u_token(username, domain) {
Ok(_token) => {
// In a full implementation, we would serialize the token into
// CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION format
// (KerbInteractiveLogon or MsV1_0InteractiveLogon structure).
//
// For the POC, we signal success. The actual serialization requires
// building the proper KERB_INTERACTIVE_LOGON or MSV1_0_INTERACTIVE_LOGON
// structure with the token handle, which is complex.
//
// TODO: Build proper credential serialization from S4U token
log::info!(
"S4U logon successful for {}\\{}, session {}",
domain,
username,
session_info.session_id
);
unsafe {
*_pcpgsr = CPGSR_RETURN_CREDENTIAL_FINISHED;
// Note: In production, pcpcs would be filled with the serialized credentials
}
Ok(())
}
Err(e) => {
log::error!("S4U logon failed for {}\\{}: {}", domain, username, e);
unsafe {
*_pcpgsr = CPGSR_NO_CREDENTIAL_FINISHED;
}
Ok(())
}
}
}
fn ReportResult(
&self,
_ntstatus: NTSTATUS,
_ntssubstatus: NTSTATUS,
_ppszoptionalstatustext: *mut PWSTR,
_pcpsioptionalstatusicon: *mut CREDENTIAL_PROVIDER_STATUS_ICON,
) -> Result<()> {
Ok(())
}
}

View File

@@ -0,0 +1,11 @@
use windows::core::GUID;
/// CLSID for the NetBird RDP Credential Provider.
/// Generated UUID: {7B3A8E5F-1C4D-4F8A-B2E6-9D0F3A7C5E1B}
pub const CLSID_NETBIRD_CREDENTIAL_PROVIDER: GUID = GUID::from_u128(
0x7B3A8E5F_1C4D_4F8A_B2E6_9D0F3A7C5E1B,
);
/// Registry path for credential providers.
pub const CREDENTIAL_PROVIDER_REGISTRY_PATH: &str =
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers";

View File

@@ -0,0 +1,309 @@
//! NetBird RDP Credential Provider for Windows.
//!
//! This DLL is a Windows Credential Provider that enables passwordless RDP access
//! to machines running the NetBird agent. It is loaded by Windows' LogonUI.exe
//! via COM when the login screen is displayed.
//!
//! ## How it works
//!
//! 1. The DLL is registered as a Credential Provider in the Windows registry
//! 2. When an RDP session begins, LogonUI loads the DLL
//! 3. The DLL queries the local NetBird agent via named pipe for pending sessions
//! 4. If a pending session exists for the connecting peer, the DLL:
//! - Shows a "NetBird Login" credential tile
//! - Performs S4U logon to create a Windows token without a password
//! - Returns the token to LogonUI for session creation
mod credential;
mod guid;
mod named_pipe_client;
mod provider;
mod s4u;
use guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
use provider::NetBirdCredentialProvider;
use std::sync::atomic::{AtomicU32, Ordering};
use windows::core::*;
use windows::Win32::Foundation::*;
use windows::Win32::System::Com::*;
/// DLL reference count for COM lifecycle management.
static DLL_REF_COUNT: AtomicU32 = AtomicU32::new(0);
/// DLL module handle.
static mut DLL_MODULE: HMODULE = HMODULE(std::ptr::null_mut());
/// COM class factory for creating NetBirdCredentialProvider instances.
#[implement(IClassFactory)]
struct NetBirdClassFactory;
impl IClassFactory_Impl for NetBirdClassFactory_Impl {
fn CreateInstance(
&self,
_punkouter: Option<&IUnknown>,
riid: *const GUID,
ppvobject: *mut *mut std::ffi::c_void,
) -> Result<()> {
unsafe {
if !ppvobject.is_null() {
*ppvobject = std::ptr::null_mut();
}
}
if _punkouter.is_some() {
return Err(CLASS_E_NOAGGREGATION.into());
}
let provider = NetBirdCredentialProvider::new();
let unknown: IUnknown = provider.into();
unsafe {
unknown.query(riid, ppvobject).ok()
}
}
fn LockServer(&self, flock: BOOL) -> Result<()> {
if flock.as_bool() {
DLL_REF_COUNT.fetch_add(1, Ordering::SeqCst);
} else {
DLL_REF_COUNT.fetch_sub(1, Ordering::SeqCst);
}
Ok(())
}
}
/// DLL entry point.
#[no_mangle]
extern "system" fn DllMain(hinstance: HMODULE, reason: u32, _reserved: *mut std::ffi::c_void) -> BOOL {
const DLL_PROCESS_ATTACH: u32 = 1;
if reason == DLL_PROCESS_ATTACH {
unsafe {
DLL_MODULE = hinstance;
}
}
TRUE
}
/// COM entry point: returns a class factory for the requested CLSID.
#[no_mangle]
extern "system" fn DllGetClassObject(
rclsid: *const GUID,
riid: *const GUID,
ppv: *mut *mut std::ffi::c_void,
) -> HRESULT {
unsafe {
if ppv.is_null() {
return E_POINTER;
}
*ppv = std::ptr::null_mut();
if *rclsid != CLSID_NETBIRD_CREDENTIAL_PROVIDER {
return CLASS_E_CLASSNOTAVAILABLE;
}
let factory = NetBirdClassFactory;
let unknown: IUnknown = factory.into();
match unknown.query(riid, ppv) {
Ok(()) => S_OK,
Err(e) => e.code(),
}
}
}
/// COM entry point: indicates whether the DLL can be unloaded.
#[no_mangle]
extern "system" fn DllCanUnloadNow() -> HRESULT {
if DLL_REF_COUNT.load(Ordering::SeqCst) == 0 {
S_OK
} else {
S_FALSE
}
}
/// Self-registration: called by regsvr32 to register the credential provider.
#[no_mangle]
extern "system" fn DllRegisterServer() -> HRESULT {
match register_credential_provider(true) {
Ok(()) => S_OK,
Err(_) => E_FAIL,
}
}
/// Self-unregistration: called by regsvr32 /u to unregister the credential provider.
#[no_mangle]
extern "system" fn DllUnregisterServer() -> HRESULT {
match register_credential_provider(false) {
Ok(()) => S_OK,
Err(_) => E_FAIL,
}
}
fn register_credential_provider(register: bool) -> std::result::Result<(), Box<dyn std::error::Error>> {
use windows::Win32::System::Registry::*;
let clsid_str = format!("{{{:08X}-{:04X}-{:04X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}",
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data1,
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data2,
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data3,
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[0],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[1],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[2],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[3],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[4],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[5],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[6],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[7],
);
if register {
// Register under Credential Providers
let cp_key_path = format!(
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
clsid_str
);
let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
let mut hkey = HKEY::default();
unsafe {
let result = RegCreateKeyExW(
HKEY_LOCAL_MACHINE,
PCWSTR(cp_key_wide.as_ptr()),
0,
PCWSTR::null(),
REG_OPTION_NON_VOLATILE,
KEY_WRITE,
None,
&mut hkey,
None,
);
if result.is_err() {
return Err("Failed to create credential provider registry key".into());
}
let value: Vec<u16> = "NetBird RDP Credential Provider"
.encode_utf16()
.chain(std::iter::once(0))
.collect();
let _ = RegSetValueExW(
hkey,
PCWSTR::null(),
0,
REG_SZ,
Some(std::slice::from_raw_parts(
value.as_ptr() as *const u8,
value.len() * 2,
)),
);
let _ = RegCloseKey(hkey);
}
// Register CLSID in CLSID hive
let clsid_key_path = format!(r"CLSID\{}", clsid_str);
let clsid_key_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
unsafe {
let result = RegCreateKeyExW(
HKEY_CLASSES_ROOT,
PCWSTR(clsid_key_wide.as_ptr()),
0,
PCWSTR::null(),
REG_OPTION_NON_VOLATILE,
KEY_WRITE,
None,
&mut hkey,
None,
);
if result.is_err() {
return Err("Failed to create CLSID registry key".into());
}
let _ = RegCloseKey(hkey);
// InprocServer32 subkey
let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
let result = RegCreateKeyExW(
HKEY_CLASSES_ROOT,
PCWSTR(inproc_wide.as_ptr()),
0,
PCWSTR::null(),
REG_OPTION_NON_VOLATILE,
KEY_WRITE,
None,
&mut hkey,
None,
);
if result.is_err() {
return Err("Failed to create InprocServer32 registry key".into());
}
// Set DLL path
let mut dll_path = [0u16; 260];
let len = windows::Win32::System::LibraryLoader::GetModuleFileNameW(
DLL_MODULE,
&mut dll_path,
);
if len > 0 {
let _ = RegSetValueExW(
hkey,
PCWSTR::null(),
0,
REG_SZ,
Some(std::slice::from_raw_parts(
dll_path.as_ptr() as *const u8,
(len as usize + 1) * 2,
)),
);
}
// Set threading model
let threading: Vec<u16> = "Apartment"
.encode_utf16()
.chain(std::iter::once(0))
.collect();
let threading_name: Vec<u16> = "ThreadingModel"
.encode_utf16()
.chain(std::iter::once(0))
.collect();
let _ = RegSetValueExW(
hkey,
PCWSTR(threading_name.as_ptr()),
0,
REG_SZ,
Some(std::slice::from_raw_parts(
threading.as_ptr() as *const u8,
threading.len() * 2,
)),
);
let _ = RegCloseKey(hkey);
}
} else {
// Unregister
let cp_key_path = format!(
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
clsid_str
);
let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
unsafe {
let _ = RegDeleteKeyW(HKEY_LOCAL_MACHINE, PCWSTR(cp_key_wide.as_ptr()));
}
let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
let clsid_key_path = format!(r"CLSID\{}", clsid_str);
let clsid_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
unsafe {
let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(inproc_wide.as_ptr()));
let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(clsid_wide.as_ptr()));
}
}
Ok(())
}

View File

@@ -0,0 +1,135 @@
use serde::{Deserialize, Serialize};
use std::io::{Read, Write};
use std::time::Duration;
/// Named pipe path for communicating with the NetBird agent.
const PIPE_NAME: &str = r"\\.\pipe\netbird-rdp-auth";
/// Maximum response size from the agent.
const MAX_RESPONSE_SIZE: usize = 4096;
/// Timeout for named pipe operations.
const PIPE_TIMEOUT: Duration = Duration::from_secs(5);
/// Request sent to the NetBird agent via named pipe.
#[derive(Serialize)]
pub struct PipeRequest {
pub action: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub remote_ip: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
}
/// Response received from the NetBird agent via named pipe.
#[derive(Deserialize, Debug, Clone)]
pub struct PipeResponse {
pub found: bool,
#[serde(default)]
pub session_id: String,
#[serde(default)]
pub os_user: String,
#[serde(default)]
pub domain: String,
}
/// Client for communicating with the NetBird agent's named pipe server.
pub struct NamedPipeClient;
impl NamedPipeClient {
/// Query the NetBird agent for a pending RDP session matching the given remote IP.
pub fn query_pending(remote_ip: &str) -> Result<PipeResponse, PipeError> {
let request = PipeRequest {
action: "query_pending".to_string(),
remote_ip: Some(remote_ip.to_string()),
session_id: None,
};
Self::send_request(&request)
}
/// Tell the NetBird agent to consume (mark as used) a pending session.
pub fn consume_session(session_id: &str) -> Result<PipeResponse, PipeError> {
let request = PipeRequest {
action: "consume".to_string(),
remote_ip: None,
session_id: Some(session_id.to_string()),
};
Self::send_request(&request)
}
fn send_request(request: &PipeRequest) -> Result<PipeResponse, PipeError> {
let request_data =
serde_json::to_vec(request).map_err(|e| PipeError::Serialization(e.to_string()))?;
// Open named pipe (CreateFile in Windows)
let mut pipe = Self::open_pipe()?;
// Write request
pipe.write_all(&request_data)
.map_err(|e| PipeError::Write(e.to_string()))?;
// Shutdown write side to signal end of request
// For named pipes on Windows, we rely on the message boundary
pipe.flush()
.map_err(|e| PipeError::Write(e.to_string()))?;
// Read response
let mut response_data = vec![0u8; MAX_RESPONSE_SIZE];
let n = pipe
.read(&mut response_data)
.map_err(|e| PipeError::Read(e.to_string()))?;
let response: PipeResponse = serde_json::from_slice(&response_data[..n])
.map_err(|e| PipeError::Deserialization(e.to_string()))?;
Ok(response)
}
fn open_pipe() -> Result<std::fs::File, PipeError> {
// On Windows, named pipes are opened like files
use std::fs::OpenOptions;
// Try to open the pipe with a brief retry for PIPE_BUSY
for attempt in 0..3 {
match OpenOptions::new().read(true).write(true).open(PIPE_NAME) {
Ok(file) => return Ok(file),
Err(e) => {
if attempt < 2 {
std::thread::sleep(Duration::from_millis(100));
continue;
}
return Err(PipeError::Connect(format!(
"failed to open pipe {}: {}",
PIPE_NAME, e
)));
}
}
}
Err(PipeError::Connect("exhausted pipe connection attempts".to_string()))
}
}
/// Errors that can occur during named pipe communication.
#[derive(Debug)]
pub enum PipeError {
Connect(String),
Write(String),
Read(String),
Serialization(String),
Deserialization(String),
}
impl std::fmt::Display for PipeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PipeError::Connect(e) => write!(f, "pipe connect: {}", e),
PipeError::Write(e) => write!(f, "pipe write: {}", e),
PipeError::Read(e) => write!(f, "pipe read: {}", e),
PipeError::Serialization(e) => write!(f, "pipe serialization: {}", e),
PipeError::Deserialization(e) => write!(f, "pipe deserialization: {}", e),
}
}
}
impl std::error::Error for PipeError {}

View File

@@ -0,0 +1,270 @@
//! ICredentialProvider implementation.
//!
//! This is the main COM object that Windows' LogonUI.exe instantiates.
//! It determines whether to show a "NetBird Login" credential tile based on
//! whether the NetBird agent has a pending RDP session for the connecting peer.
use crate::credential::NetBirdCredential;
use crate::guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
use crate::named_pipe_client::NamedPipeClient;
use std::sync::Mutex;
use windows::core::*;
use windows::Win32::Foundation::*;
use windows::Win32::Security::Credentials::*;
use windows::Win32::System::RemoteDesktop::*;
/// The NetBird Credential Provider, loaded by LogonUI.exe via COM.
#[implement(ICredentialProvider)]
pub struct NetBirdCredentialProvider {
/// The credential tile (if a pending session was found).
credential: Mutex<Option<ICredentialProviderCredential>>,
/// Whether this provider is active for the current usage scenario.
active: Mutex<bool>,
}
impl NetBirdCredentialProvider {
pub fn new() -> Self {
Self {
credential: Mutex::new(None),
active: Mutex::new(false),
}
}
}
impl ICredentialProvider_Impl for NetBirdCredentialProvider_Impl {
fn SetUsageScenario(
&self,
cpus: CREDENTIAL_PROVIDER_USAGE_SCENARIO,
_dwflags: u32,
) -> Result<()> {
let mut active = self.active.lock().unwrap();
match cpus {
CPUS_LOGON | CPUS_UNLOCK_WORKSTATION => {
// We activate for RDP logon and unlock scenarios
*active = true;
log::info!("NetBird CP activated for usage scenario {:?}", cpus.0);
Ok(())
}
_ => {
// Don't activate for credui or other scenarios
*active = false;
Err(E_NOTIMPL.into())
}
}
}
fn SetSerialization(
&self,
_pcpcs: *const CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION,
) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn Advise(
&self,
_pcpe: Option<&ICredentialProviderEvents>,
_upadvisecontext: usize,
) -> Result<()> {
Ok(())
}
fn UnAdvise(&self) -> Result<()> {
Ok(())
}
fn GetFieldDescriptorCount(&self) -> Result<u32> {
// We have one field: a large text label showing "NetBird: Logging in as <user>"
Ok(1)
}
fn GetFieldDescriptorAt(
&self,
_dwindex: u32,
_ppcpfd: *mut *mut CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR,
) -> Result<()> {
if _dwindex != 0 {
return Err(E_INVALIDARG.into());
}
let label = "NetBird Login";
let wide: Vec<u16> = label.encode_utf16().chain(std::iter::once(0)).collect();
unsafe {
let desc = windows::Win32::System::Com::CoTaskMemAlloc(
std::mem::size_of::<CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR>(),
) as *mut CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR;
if desc.is_null() {
return Err(E_OUTOFMEMORY.into());
}
let label_mem =
windows::Win32::System::Com::CoTaskMemAlloc(wide.len() * 2) as *mut u16;
if label_mem.is_null() {
windows::Win32::System::Com::CoTaskMemFree(Some(desc as *const _));
return Err(E_OUTOFMEMORY.into());
}
std::ptr::copy_nonoverlapping(wide.as_ptr(), label_mem, wide.len());
(*desc).dwFieldID = 0;
(*desc).cpft = CPFT_LARGE_TEXT;
(*desc).pszLabel = PWSTR(label_mem);
(*desc).guidFieldType = GUID::zeroed();
*_ppcpfd = desc;
}
Ok(())
}
fn GetCredentialCount(
&self,
_pdwcount: *mut u32,
_pdwdefault: *mut u32,
_pbautologinwithdefault: *mut BOOL,
) -> Result<()> {
let active = self.active.lock().unwrap();
if !*active {
unsafe {
*_pdwcount = 0;
*_pdwdefault = u32::MAX;
*_pbautologinwithdefault = FALSE;
}
return Ok(());
}
// Try to get the client IP of the current RDP session
let remote_ip = match get_rdp_client_ip() {
Some(ip) => ip,
None => {
log::debug!("NetBird CP: could not determine RDP client IP");
unsafe {
*_pdwcount = 0;
*_pdwdefault = u32::MAX;
*_pbautologinwithdefault = FALSE;
}
return Ok(());
}
};
// Query the NetBird agent for a pending session
match NamedPipeClient::query_pending(&remote_ip) {
Ok(response) if response.found => {
log::info!(
"NetBird CP: found pending session for {} -> {}",
remote_ip,
response.os_user
);
let cred = NetBirdCredential::new(remote_ip, response);
let icred: ICredentialProviderCredential = cred.into();
let mut credential = self.credential.lock().unwrap();
*credential = Some(icred);
unsafe {
*_pdwcount = 1;
*_pdwdefault = 0;
*_pbautologinwithdefault = TRUE; // auto-logon
}
}
Ok(_) => {
log::debug!("NetBird CP: no pending session for {}", remote_ip);
unsafe {
*_pdwcount = 0;
*_pdwdefault = u32::MAX;
*_pbautologinwithdefault = FALSE;
}
}
Err(e) => {
log::debug!("NetBird CP: pipe query failed: {}", e);
unsafe {
*_pdwcount = 0;
*_pdwdefault = u32::MAX;
*_pbautologinwithdefault = FALSE;
}
}
}
Ok(())
}
fn GetCredentialAt(
&self,
_dwindex: u32,
_ppcpc: *mut Option<ICredentialProviderCredential>,
) -> Result<()> {
if _dwindex != 0 {
return Err(E_INVALIDARG.into());
}
let credential = self.credential.lock().unwrap();
match &*credential {
Some(cred) => {
unsafe {
*_ppcpc = Some(cred.clone());
}
Ok(())
}
None => Err(E_UNEXPECTED.into()),
}
}
}
/// Get the IP address of the remote RDP client for the current session.
fn get_rdp_client_ip() -> Option<String> {
unsafe {
// Get the current session ID
let process_id = windows::Win32::System::Threading::GetCurrentProcessId();
let mut session_id = 0u32;
if !windows::Win32::System::RemoteDesktop::ProcessIdToSessionId(process_id, &mut session_id)
.as_bool()
{
log::debug!("ProcessIdToSessionId failed");
return None;
}
// Query the client address
let mut buffer: *mut WTS_CLIENT_ADDRESS = std::ptr::null_mut();
let mut bytes_returned = 0u32;
let result = WTSQuerySessionInformationW(
WTS_CURRENT_SERVER_HANDLE,
session_id,
WTS_INFO_CLASS(14), // WTSClientAddress
&mut buffer as *mut _ as *mut *mut u16,
&mut bytes_returned,
);
if !result.as_bool() || buffer.is_null() {
log::debug!("WTSQuerySessionInformation(WTSClientAddress) failed");
return None;
}
let client_addr = &*buffer;
let ip = match client_addr.AddressFamily as u32 {
// AF_INET
2 => {
let addr = &client_addr.Address;
Some(format!("{}.{}.{}.{}", addr[2], addr[3], addr[4], addr[5]))
}
// AF_INET6
23 => {
// IPv6 - extract from Address bytes
let addr = &client_addr.Address;
Some(format!(
"{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}",
addr[2], addr[3], addr[4], addr[5], addr[6], addr[7], addr[8], addr[9],
addr[10], addr[11], addr[12], addr[13], addr[14], addr[15], addr[16], addr[17]
))
}
_ => None,
};
WTSFreeMemory(buffer as *mut std::ffi::c_void);
ip
}
}

View File

@@ -0,0 +1,398 @@
//! S4U (Service for User) authentication for Windows.
//!
//! This module ports the S4U logon logic from the Go implementation at:
//! `client/ssh/server/executor_windows.go:generateS4UUserToken()`
//!
//! It creates Windows logon tokens without requiring a password, using the LSA
//! (Local Security Authority) S4U mechanism. This is the same approach used by
//! OpenSSH for Windows for public key authentication.
use std::ptr;
use windows::core::{PCSTR, PWSTR};
use windows::Win32::Foundation::{HANDLE, LUID, NTSTATUS, PSID};
use windows::Win32::Security::Authentication::Identity::{
LsaDeregisterLogonProcess, LsaFreeReturnBuffer, LsaLogonUser, LsaLookupAuthenticationPackage,
LsaRegisterLogonProcess, KERB_S4U_LOGON, MSV1_0_S4U_LOGON, MSV1_0_S4U_LOGON_FLAG_CHECK_LOGONHOURS,
SECURITY_LOGON_TYPE,
};
use windows::Win32::Security::{
QUOTA_LIMITS, TOKEN_SOURCE,
};
/// Status code for successful LSA operations.
const STATUS_SUCCESS: i32 = 0;
/// Network logon type (used for S4U).
const LOGON32_LOGON_NETWORK: SECURITY_LOGON_TYPE = SECURITY_LOGON_TYPE(3);
/// Kerberos S4U logon message type.
const KERB_S4U_LOGON_TYPE: u32 = 12;
/// MSV1_0 S4U logon message type.
const MSV1_0_S4U_LOGON_TYPE: u32 = 12;
/// Authentication package name for Kerberos.
const KERBEROS_PACKAGE: &str = "Kerberos";
/// Authentication package name for MSV1_0 (local users).
const MSV1_0_PACKAGE: &str = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0";
/// Result of a successful S4U logon.
pub struct S4UToken {
pub handle: HANDLE,
}
impl Drop for S4UToken {
fn drop(&mut self) {
if !self.handle.is_invalid() {
unsafe {
let _ = windows::Win32::Foundation::CloseHandle(self.handle);
}
}
}
}
/// Errors from S4U logon operations.
#[derive(Debug)]
pub enum S4UError {
LsaRegister(NTSTATUS),
LookupPackage(NTSTATUS),
LogonUser(NTSTATUS, i32),
AllocateLuid,
InvalidUsername(String),
Utf16Conversion(String),
}
impl std::fmt::Display for S4UError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
S4UError::LsaRegister(s) => write!(f, "LsaRegisterLogonProcess: 0x{:x}", s.0),
S4UError::LookupPackage(s) => write!(f, "LsaLookupAuthenticationPackage: 0x{:x}", s.0),
S4UError::LogonUser(s, sub) => {
write!(f, "LsaLogonUser S4U: NTSTATUS=0x{:x}, SubStatus=0x{:x}", s.0, sub)
}
S4UError::AllocateLuid => write!(f, "AllocateLocallyUniqueId failed"),
S4UError::InvalidUsername(u) => write!(f, "invalid username: {}", u),
S4UError::Utf16Conversion(s) => write!(f, "UTF-16 conversion: {}", s),
}
}
}
impl std::error::Error for S4UError {}
/// Generate a Windows logon token using S4U authentication.
///
/// This creates a token for the specified user without requiring a password.
/// The calling process must have SeTcbPrivilege (typically SYSTEM).
///
/// # Arguments
/// * `username` - The Windows username (without domain prefix)
/// * `domain` - The domain name ("." for local users)
///
/// # Returns
/// An `S4UToken` containing the Windows logon token handle.
pub fn generate_s4u_token(username: &str, domain: &str) -> Result<S4UToken, S4UError> {
if username.is_empty() {
return Err(S4UError::InvalidUsername("empty username".to_string()));
}
let is_local = is_local_user(domain);
// Initialize LSA connection
let lsa_handle = initialize_lsa_connection()?;
// Lookup authentication package
let auth_package_id = lookup_auth_package(lsa_handle, is_local)?;
// Perform S4U logon
let result = perform_s4u_logon(lsa_handle, auth_package_id, username, domain, is_local);
// Cleanup LSA connection
unsafe {
let _ = LsaDeregisterLogonProcess(lsa_handle);
}
result
}
fn is_local_user(domain: &str) -> bool {
domain.is_empty() || domain == "."
}
fn initialize_lsa_connection() -> Result<HANDLE, S4UError> {
let process_name = "NetBird\0";
let mut lsa_string = windows::Win32::Security::Authentication::Identity::LSA_STRING {
Length: (process_name.len() - 1) as u16,
MaximumLength: process_name.len() as u16,
Buffer: windows::core::PSTR(process_name.as_ptr() as *mut u8),
};
let mut lsa_handle = HANDLE::default();
let mut mode = 0u32;
let status = unsafe {
LsaRegisterLogonProcess(&mut lsa_string, &mut lsa_handle, &mut mode)
};
if status.0 != STATUS_SUCCESS {
return Err(S4UError::LsaRegister(status));
}
Ok(lsa_handle)
}
fn lookup_auth_package(lsa_handle: HANDLE, is_local: bool) -> Result<u32, S4UError> {
let package_name = if is_local { MSV1_0_PACKAGE } else { KERBEROS_PACKAGE };
let package_with_null = format!("{}\0", package_name);
let mut lsa_string = windows::Win32::Security::Authentication::Identity::LSA_STRING {
Length: (package_with_null.len() - 1) as u16,
MaximumLength: package_with_null.len() as u16,
Buffer: windows::core::PSTR(package_with_null.as_ptr() as *mut u8),
};
let mut auth_package_id = 0u32;
let status = unsafe {
LsaLookupAuthenticationPackage(lsa_handle, &mut lsa_string, &mut auth_package_id)
};
if status.0 != STATUS_SUCCESS {
return Err(S4UError::LookupPackage(status));
}
Ok(auth_package_id)
}
fn perform_s4u_logon(
lsa_handle: HANDLE,
auth_package_id: u32,
username: &str,
domain: &str,
is_local: bool,
) -> Result<S4UToken, S4UError> {
// Prepare token source
let mut source_name = [0u8; 8];
let name_bytes = b"netbird";
source_name[..name_bytes.len()].copy_from_slice(name_bytes);
let mut source_id = LUID::default();
let alloc_ok = unsafe {
windows::Win32::System::SystemInformation::GetSystemTimeAsFileTime(
&mut std::mem::zeroed(),
);
// Use a simpler approach - just use the current time as a unique ID
source_id.LowPart = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.subsec_nanos();
source_id.HighPart = std::process::id() as i32;
true
};
if !alloc_ok {
return Err(S4UError::AllocateLuid);
}
let token_source = TOKEN_SOURCE {
SourceName: source_name,
SourceIdentifier: source_id,
};
let origin_name_str = "netbird\0";
let mut origin_name = windows::Win32::Security::Authentication::Identity::LSA_STRING {
Length: (origin_name_str.len() - 1) as u16,
MaximumLength: origin_name_str.len() as u16,
Buffer: windows::core::PSTR(origin_name_str.as_ptr() as *mut u8),
};
// Build the logon info structure
let (logon_info_ptr, logon_info_size) = if is_local {
build_msv1_0_s4u_logon(username)?
} else {
build_kerb_s4u_logon(username, domain)?
};
let mut profile: *mut std::ffi::c_void = ptr::null_mut();
let mut profile_size = 0u32;
let mut logon_id = LUID::default();
let mut token = HANDLE::default();
let mut quotas = QUOTA_LIMITS::default();
let mut sub_status: i32 = 0;
let status = unsafe {
LsaLogonUser(
lsa_handle,
&mut origin_name,
LOGON32_LOGON_NETWORK,
auth_package_id,
logon_info_ptr as *const std::ffi::c_void,
logon_info_size as u32,
None, // local groups
&token_source,
&mut profile,
&mut profile_size,
&mut logon_id,
&mut token,
&mut quotas,
&mut sub_status,
)
};
// Free profile buffer if allocated
if !profile.is_null() {
unsafe {
let _ = LsaFreeReturnBuffer(profile);
}
}
// Free the logon info buffer
unsafe {
let layout = std::alloc::Layout::from_size_align_unchecked(logon_info_size, 8);
std::alloc::dealloc(logon_info_ptr as *mut u8, layout);
}
if status.0 != STATUS_SUCCESS {
return Err(S4UError::LogonUser(status, sub_status));
}
Ok(S4UToken { handle: token })
}
/// Build MSV1_0_S4U_LOGON structure for local users.
fn build_msv1_0_s4u_logon(username: &str) -> Result<(*mut u8, usize), S4UError> {
let username_utf16: Vec<u16> = username.encode_utf16().chain(std::iter::once(0)).collect();
let domain_utf16: Vec<u16> = ".".encode_utf16().chain(std::iter::once(0)).collect();
let username_byte_size = username_utf16.len() * 2;
let domain_byte_size = domain_utf16.len() * 2;
// MSV1_0_S4U_LOGON structure:
// MessageType: u32 (4 bytes)
// Flags: u32 (4 bytes)
// UserPrincipalName: UNICODE_STRING (8 bytes on 32-bit, 16 bytes on 64-bit)
// DomainName: UNICODE_STRING
let struct_size = std::mem::size_of::<MSV1_0_S4U_LOGON_HEADER>();
let total_size = struct_size + username_byte_size + domain_byte_size;
let layout = std::alloc::Layout::from_size_align(total_size, 8).unwrap();
let buffer = unsafe { std::alloc::alloc_zeroed(layout) };
if buffer.is_null() {
return Err(S4UError::Utf16Conversion("allocation failed".to_string()));
}
// For the POC, we'll set up the raw bytes manually since the windows-rs
// MSV1_0_S4U_LOGON structure layout may differ.
// This is a simplified version - in production, use proper FFI bindings.
unsafe {
// MessageType = MSV1_0_S4U_LOGON_TYPE (12)
*(buffer as *mut u32) = MSV1_0_S4U_LOGON_TYPE;
// Flags = 0
*((buffer as *mut u32).add(1)) = 0;
// Copy username UTF-16 after the structure
let username_offset = struct_size;
let username_dest = buffer.add(username_offset);
ptr::copy_nonoverlapping(
username_utf16.as_ptr() as *const u8,
username_dest,
username_byte_size,
);
// Copy domain UTF-16 after username
let domain_offset = username_offset + username_byte_size;
let domain_dest = buffer.add(domain_offset);
ptr::copy_nonoverlapping(
domain_utf16.as_ptr() as *const u8,
domain_dest,
domain_byte_size,
);
// Set UNICODE_STRING for UserPrincipalName (offset 8 on 64-bit)
// Length, MaximumLength, Buffer pointer
let upn_ptr = buffer.add(8) as *mut u16;
*upn_ptr = ((username_utf16.len() - 1) * 2) as u16; // Length (without null)
*(upn_ptr.add(1)) = (username_utf16.len() * 2) as u16; // MaximumLength
*((buffer.add(8 + 4)) as *mut *const u8) = username_dest; // Buffer
// Set UNICODE_STRING for DomainName
let dn_offset = 8 + std::mem::size_of::<UnicodeStringRaw>();
let dn_ptr = buffer.add(dn_offset) as *mut u16;
*dn_ptr = ((domain_utf16.len() - 1) * 2) as u16;
*(dn_ptr.add(1)) = (domain_utf16.len() * 2) as u16;
*((buffer.add(dn_offset + 4)) as *mut *const u8) = domain_dest;
}
Ok((buffer, total_size))
}
/// Build KERB_S4U_LOGON structure for domain users.
fn build_kerb_s4u_logon(username: &str, domain: &str) -> Result<(*mut u8, usize), S4UError> {
// Build UPN: username@domain
let upn = format!("{}@{}", username, domain);
let upn_utf16: Vec<u16> = upn.encode_utf16().chain(std::iter::once(0)).collect();
let upn_byte_size = upn_utf16.len() * 2;
let struct_size = std::mem::size_of::<KerbS4ULogonHeader>();
let total_size = struct_size + upn_byte_size;
let layout = std::alloc::Layout::from_size_align(total_size, 8).unwrap();
let buffer = unsafe { std::alloc::alloc_zeroed(layout) };
if buffer.is_null() {
return Err(S4UError::Utf16Conversion("allocation failed".to_string()));
}
unsafe {
// MessageType = KERB_S4U_LOGON_TYPE (12)
*(buffer as *mut u32) = KERB_S4U_LOGON_TYPE;
// Flags = 0
*((buffer as *mut u32).add(1)) = 0;
// Copy UPN UTF-16 after the structure
let upn_offset = struct_size;
let upn_dest = buffer.add(upn_offset);
ptr::copy_nonoverlapping(
upn_utf16.as_ptr() as *const u8,
upn_dest,
upn_byte_size,
);
// Set UNICODE_STRING for ClientUpn (offset 8)
let upn_str_ptr = buffer.add(8) as *mut u16;
*upn_str_ptr = ((upn_utf16.len() - 1) * 2) as u16;
*(upn_str_ptr.add(1)) = (upn_utf16.len() * 2) as u16;
*((buffer.add(8 + 4)) as *mut *const u8) = upn_dest;
// ClientRealm is empty (zeroed)
}
Ok((buffer, total_size))
}
/// Raw UNICODE_STRING layout for size calculation.
#[repr(C)]
struct UnicodeStringRaw {
_length: u16,
_maximum_length: u16,
_buffer: *const u16,
}
/// Header size for MSV1_0_S4U_LOGON (MessageType + Flags + 2x UNICODE_STRING).
#[repr(C)]
struct MSV1_0_S4U_LOGON_HEADER {
_message_type: u32,
_flags: u32,
_user_principal_name: UnicodeStringRaw,
_domain_name: UnicodeStringRaw,
}
/// Header size for KERB_S4U_LOGON (MessageType + Flags + 2x UNICODE_STRING).
#[repr(C)]
struct KerbS4ULogonHeader {
_message_type: u32,
_flags: u32,
_client_upn: UnicodeStringRaw,
_client_realm: UnicodeStringRaw,
}

21
client/rdp/server/addr.go Normal file
View File

@@ -0,0 +1,21 @@
package server
import (
"fmt"
"net/netip"
)
// parseAddr parses a string into a netip.Addr, stripping any port or zone.
func parseAddr(s string) (netip.Addr, error) {
// Try as plain IP first
if addr, err := netip.ParseAddr(s); err == nil {
return addr, nil
}
// Try as IP:port
if addrPort, err := netip.ParseAddrPort(s); err == nil {
return addrPort.Addr(), nil
}
return netip.Addr{}, fmt.Errorf("invalid IP address: %s", s)
}

View File

@@ -0,0 +1,13 @@
//go:build !windows
package server
// RegisterCredentialProvider is a no-op on non-Windows platforms.
func RegisterCredentialProvider() error {
return nil
}
// UnregisterCredentialProvider is a no-op on non-Windows platforms.
func UnregisterCredentialProvider() error {
return nil
}

View File

@@ -0,0 +1,66 @@
//go:build windows
package server
import (
"fmt"
"os"
"os/exec"
"path/filepath"
log "github.com/sirupsen/logrus"
)
const (
// credProvDLLName is the filename of the credential provider DLL.
credProvDLLName = "netbird_credprov.dll"
)
// RegisterCredentialProvider registers the NetBird Credential Provider COM DLL
// using regsvr32. The DLL must be shipped alongside the NetBird executable.
func RegisterCredentialProvider() error {
dllPath, err := findCredProvDLL()
if err != nil {
return fmt.Errorf("find credential provider DLL: %w", err)
}
cmd := exec.Command("regsvr32", "/s", dllPath)
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("regsvr32 %s: %w (output: %s)", dllPath, err, string(output))
}
log.Infof("registered RDP credential provider: %s", dllPath)
return nil
}
// UnregisterCredentialProvider unregisters the NetBird Credential Provider COM DLL.
func UnregisterCredentialProvider() error {
dllPath, err := findCredProvDLL()
if err != nil {
log.Debugf("credential provider DLL not found for unregistration: %v", err)
return nil
}
cmd := exec.Command("regsvr32", "/s", "/u", dllPath)
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("regsvr32 /u %s: %w (output: %s)", dllPath, err, string(output))
}
log.Infof("unregistered RDP credential provider: %s", dllPath)
return nil
}
// findCredProvDLL locates the credential provider DLL next to the running executable.
func findCredProvDLL() (string, error) {
exePath, err := os.Executable()
if err != nil {
return "", fmt.Errorf("get executable path: %w", err)
}
dllPath := filepath.Join(filepath.Dir(exePath), credProvDLLName)
if _, err := os.Stat(dllPath); err != nil {
return "", fmt.Errorf("DLL not found at %s: %w", dllPath, err)
}
return dllPath, nil
}

View File

@@ -0,0 +1,184 @@
package server
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net/netip"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
)
const (
// DefaultSessionTTL is the default time-to-live for pending RDP sessions.
DefaultSessionTTL = 60 * time.Second
// cleanupInterval is how often the store checks for expired sessions.
cleanupInterval = 10 * time.Second
// nonceLength is the length of the nonce in bytes.
nonceLength = 32
)
// PendingRDPSession represents an authorized but not yet consumed RDP session.
type PendingRDPSession struct {
SessionID string
PeerIP netip.Addr
OSUsername string
Domain string
JWTUserID string // for audit trail
Nonce string // replay protection
CreatedAt time.Time
ExpiresAt time.Time
consumed bool
}
// PendingStore manages pending RDP session entries with automatic expiration.
type PendingStore struct {
mu sync.RWMutex
sessions map[string]*PendingRDPSession // keyed by SessionID
nonces map[string]struct{} // seen nonces for replay protection
ttl time.Duration
}
// NewPendingStore creates a new pending session store with the given TTL.
func NewPendingStore(ttl time.Duration) *PendingStore {
if ttl <= 0 {
ttl = DefaultSessionTTL
}
return &PendingStore{
sessions: make(map[string]*PendingRDPSession),
nonces: make(map[string]struct{}),
ttl: ttl,
}
}
// Add creates a new pending RDP session and returns it.
func (ps *PendingStore) Add(peerIP netip.Addr, osUsername, domain, jwtUserID, nonce string) (*PendingRDPSession, error) {
ps.mu.Lock()
defer ps.mu.Unlock()
// Check nonce for replay protection
if _, seen := ps.nonces[nonce]; seen {
return nil, fmt.Errorf("duplicate nonce: replay detected")
}
ps.nonces[nonce] = struct{}{}
now := time.Now()
session := &PendingRDPSession{
SessionID: uuid.New().String(),
PeerIP: peerIP,
OSUsername: osUsername,
Domain: domain,
JWTUserID: jwtUserID,
Nonce: nonce,
CreatedAt: now,
ExpiresAt: now.Add(ps.ttl),
}
ps.sessions[session.SessionID] = session
log.Debugf("RDP pending session created: id=%s peer=%s user=%s domain=%s expires=%s",
session.SessionID, peerIP, osUsername, domain, session.ExpiresAt.Format(time.RFC3339))
return session, nil
}
// QueryByPeerIP finds the first non-consumed, non-expired pending session for the given peer IP.
func (ps *PendingStore) QueryByPeerIP(peerIP netip.Addr) (*PendingRDPSession, bool) {
ps.mu.RLock()
defer ps.mu.RUnlock()
now := time.Now()
for _, session := range ps.sessions {
if session.PeerIP == peerIP && !session.consumed && now.Before(session.ExpiresAt) {
return session, true
}
}
return nil, false
}
// Consume marks a session as consumed (single-use). Returns true if the session
// was found and successfully consumed, false if it was already consumed, expired, or not found.
func (ps *PendingStore) Consume(sessionID string) bool {
ps.mu.Lock()
defer ps.mu.Unlock()
session, exists := ps.sessions[sessionID]
if !exists {
return false
}
if session.consumed {
log.Debugf("RDP pending session already consumed: id=%s", sessionID)
return false
}
if time.Now().After(session.ExpiresAt) {
log.Debugf("RDP pending session expired: id=%s", sessionID)
return false
}
session.consumed = true
log.Debugf("RDP pending session consumed: id=%s peer=%s user=%s",
sessionID, session.PeerIP, session.OSUsername)
return true
}
// StartCleanup runs a background goroutine that periodically removes expired sessions.
func (ps *PendingStore) StartCleanup(ctx context.Context) {
go func() {
ticker := time.NewTicker(cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
ps.cleanup()
}
}
}()
}
// cleanup removes expired and consumed sessions.
func (ps *PendingStore) cleanup() {
ps.mu.Lock()
defer ps.mu.Unlock()
now := time.Now()
for id, session := range ps.sessions {
if now.After(session.ExpiresAt) || session.consumed {
delete(ps.sessions, id)
delete(ps.nonces, session.Nonce)
}
}
}
// Count returns the number of active (non-expired, non-consumed) sessions.
func (ps *PendingStore) Count() int {
ps.mu.RLock()
defer ps.mu.RUnlock()
count := 0
now := time.Now()
for _, session := range ps.sessions {
if !session.consumed && now.Before(session.ExpiresAt) {
count++
}
}
return count
}
// GenerateNonce creates a cryptographically random nonce for replay protection.
func GenerateNonce() (string, error) {
b := make([]byte, nonceLength)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("generate nonce: %w", err)
}
return hex.EncodeToString(b), nil
}

View File

@@ -0,0 +1,268 @@
package server
import (
"context"
"net/netip"
"sync"
"testing"
"time"
)
func TestPendingStore_AddAndQuery(t *testing.T) {
store := NewPendingStore(DefaultSessionTTL)
peerIP := netip.MustParseAddr("100.64.0.1")
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-1")
if err != nil {
t.Fatalf("Add failed: %v", err)
}
if session.SessionID == "" {
t.Fatal("expected non-empty session ID")
}
if session.PeerIP != peerIP {
t.Errorf("expected peer IP %s, got %s", peerIP, session.PeerIP)
}
if session.OSUsername != "admin" {
t.Errorf("expected username admin, got %s", session.OSUsername)
}
// Query should find the session
found, ok := store.QueryByPeerIP(peerIP)
if !ok {
t.Fatal("expected to find pending session")
}
if found.SessionID != session.SessionID {
t.Errorf("expected session %s, got %s", session.SessionID, found.SessionID)
}
// Query for different IP should not find anything
_, ok = store.QueryByPeerIP(netip.MustParseAddr("100.64.0.2"))
if ok {
t.Fatal("expected no session for different IP")
}
}
func TestPendingStore_Consume(t *testing.T) {
store := NewPendingStore(DefaultSessionTTL)
peerIP := netip.MustParseAddr("100.64.0.1")
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-2")
if err != nil {
t.Fatalf("Add failed: %v", err)
}
// First consume should succeed
if !store.Consume(session.SessionID) {
t.Fatal("expected first consume to succeed")
}
// Second consume should fail (already consumed)
if store.Consume(session.SessionID) {
t.Fatal("expected second consume to fail")
}
// Query should no longer find consumed session
_, ok := store.QueryByPeerIP(peerIP)
if ok {
t.Fatal("expected consumed session to not be found by query")
}
}
func TestPendingStore_Expiry(t *testing.T) {
store := NewPendingStore(50 * time.Millisecond)
peerIP := netip.MustParseAddr("100.64.0.1")
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-3")
if err != nil {
t.Fatalf("Add failed: %v", err)
}
// Should be found immediately
_, ok := store.QueryByPeerIP(peerIP)
if !ok {
t.Fatal("expected to find session before expiry")
}
// Wait for expiry
time.Sleep(100 * time.Millisecond)
// Should not be found after expiry
_, ok = store.QueryByPeerIP(peerIP)
if ok {
t.Fatal("expected session to be expired")
}
// Consume should also fail
if store.Consume(session.SessionID) {
t.Fatal("expected consume of expired session to fail")
}
}
func TestPendingStore_ReplayProtection(t *testing.T) {
store := NewPendingStore(DefaultSessionTTL)
peerIP := netip.MustParseAddr("100.64.0.1")
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-same")
if err != nil {
t.Fatalf("first Add failed: %v", err)
}
// Same nonce should be rejected
_, err = store.Add(peerIP, "admin", ".", "user@example.com", "nonce-same")
if err == nil {
t.Fatal("expected duplicate nonce to be rejected")
}
}
func TestPendingStore_Cleanup(t *testing.T) {
store := NewPendingStore(50 * time.Millisecond)
peerIP := netip.MustParseAddr("100.64.0.1")
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-cleanup")
if err != nil {
t.Fatalf("Add failed: %v", err)
}
if store.Count() != 1 {
t.Fatalf("expected count 1, got %d", store.Count())
}
// Wait for expiry then trigger cleanup
time.Sleep(100 * time.Millisecond)
store.cleanup()
if store.Count() != 0 {
t.Fatalf("expected count 0 after cleanup, got %d", store.Count())
}
}
func TestPendingStore_CleanupBackground(t *testing.T) {
store := NewPendingStore(50 * time.Millisecond)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
store.StartCleanup(ctx)
peerIP := netip.MustParseAddr("100.64.0.1")
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-bg-cleanup")
if err != nil {
t.Fatalf("Add failed: %v", err)
}
// Wait for expiry + cleanup interval
time.Sleep(200 * time.Millisecond)
_, ok := store.QueryByPeerIP(peerIP)
if ok {
t.Fatal("expected session to be cleaned up")
}
}
func TestPendingStore_ConcurrentAccess(t *testing.T) {
store := NewPendingStore(DefaultSessionTTL)
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
ip := netip.AddrFrom4([4]byte{100, 64, byte(i / 256), byte(i % 256)})
nonce := "nonce-" + string(rune(i+'A'))
if i >= 26 {
nonce = "nonce-" + string(rune(i-26+'a'))
}
session, err := store.Add(ip, "admin", ".", "user", nonce)
if err != nil {
return // nonce collision in test is expected
}
store.QueryByPeerIP(ip)
store.Consume(session.SessionID)
}(i)
}
wg.Wait()
}
func TestPendingStore_MultipleSessions(t *testing.T) {
store := NewPendingStore(DefaultSessionTTL)
ip1 := netip.MustParseAddr("100.64.0.1")
ip2 := netip.MustParseAddr("100.64.0.2")
s1, err := store.Add(ip1, "admin", ".", "user1", "nonce-a")
if err != nil {
t.Fatalf("Add s1 failed: %v", err)
}
s2, err := store.Add(ip2, "jdoe", "DOMAIN", "user2", "nonce-b")
if err != nil {
t.Fatalf("Add s2 failed: %v", err)
}
// Query each
found1, ok := store.QueryByPeerIP(ip1)
if !ok || found1.SessionID != s1.SessionID {
t.Fatal("expected to find s1")
}
found2, ok := store.QueryByPeerIP(ip2)
if !ok || found2.SessionID != s2.SessionID {
t.Fatal("expected to find s2")
}
if found2.Domain != "DOMAIN" {
t.Errorf("expected domain DOMAIN, got %s", found2.Domain)
}
if store.Count() != 2 {
t.Errorf("expected count 2, got %d", store.Count())
}
}
func TestGenerateNonce(t *testing.T) {
nonce1, err := GenerateNonce()
if err != nil {
t.Fatalf("GenerateNonce failed: %v", err)
}
nonce2, err := GenerateNonce()
if err != nil {
t.Fatalf("GenerateNonce failed: %v", err)
}
if len(nonce1) != nonceLength*2 { // hex encoding doubles the length
t.Errorf("expected nonce length %d, got %d", nonceLength*2, len(nonce1))
}
if nonce1 == nonce2 {
t.Error("expected unique nonces")
}
}
func TestParseWindowsUsername(t *testing.T) {
tests := []struct {
input string
expectedUser string
expectedDomain string
}{
{"admin", "admin", "."},
{"DOMAIN\\admin", "admin", "DOMAIN"},
{"admin@domain.com", "admin", "domain.com"},
{".\\localuser", "localuser", "."},
}
for _, tt := range tests {
user, domain := parseWindowsUsername(tt.input)
if user != tt.expectedUser {
t.Errorf("parseWindowsUsername(%q) user = %q, want %q", tt.input, user, tt.expectedUser)
}
if domain != tt.expectedDomain {
t.Errorf("parseWindowsUsername(%q) domain = %q, want %q", tt.input, domain, tt.expectedDomain)
}
}
}

View File

@@ -0,0 +1,19 @@
//go:build !windows
package server
import "context"
type stubPipeServer struct{}
func newPipeServer(_ *PendingStore) PipeServer {
return &stubPipeServer{}
}
func (s *stubPipeServer) Start(_ context.Context) error {
return nil
}
func (s *stubPipeServer) Stop() error {
return nil
}

View File

@@ -0,0 +1,164 @@
//go:build windows
package server
import (
"context"
"encoding/json"
"io"
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/Microsoft/go-winio"
)
const (
// PipeName is the named pipe path used for IPC between the NetBird agent and
// the Credential Provider DLL.
PipeName = `\\.\pipe\netbird-rdp-auth`
// pipeSDDL restricts access to LOCAL_SYSTEM (SY) and Administrators (BA).
pipeSDDL = "D:P(A;;GA;;;SY)(A;;GA;;;BA)"
// maxPipeRequestSize is the maximum size of a pipe request in bytes.
maxPipeRequestSize = 4096
)
// windowsPipeServer implements the PipeServer interface for Windows.
type windowsPipeServer struct {
pending *PendingStore
listener net.Listener
mu sync.Mutex
ctx context.Context
cancel context.CancelFunc
}
func newPipeServer(pending *PendingStore) PipeServer {
return &windowsPipeServer{
pending: pending,
}
}
func (s *windowsPipeServer) Start(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
s.ctx, s.cancel = context.WithCancel(ctx)
cfg := &winio.PipeConfig{
SecurityDescriptor: pipeSDDL,
}
listener, err := winio.ListenPipe(PipeName, cfg)
if err != nil {
return err
}
s.listener = listener
go s.acceptLoop()
log.Infof("RDP named pipe server started on %s", PipeName)
return nil
}
func (s *windowsPipeServer) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cancel != nil {
s.cancel()
}
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
return err
}
return nil
}
func (s *windowsPipeServer) acceptLoop() {
for {
conn, err := s.listener.Accept()
if err != nil {
if s.ctx.Err() != nil {
return
}
log.Debugf("RDP pipe accept error: %v", err)
continue
}
go s.handlePipeConnection(conn)
}
}
func (s *windowsPipeServer) handlePipeConnection(conn net.Conn) {
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("RDP pipe close: %v", err)
}
}()
data, err := io.ReadAll(io.LimitReader(conn, maxPipeRequestSize))
if err != nil {
log.Debugf("RDP pipe read: %v", err)
return
}
var req PipeRequest
if err := json.Unmarshal(data, &req); err != nil {
log.Debugf("RDP pipe unmarshal: %v", err)
return
}
var resp PipeResponse
switch req.Action {
case PipeActionQuery:
resp = s.handleQuery(req.RemoteIP)
case PipeActionConsume:
resp = s.handleConsume(req.SessionID)
default:
log.Debugf("RDP pipe unknown action: %s", req.Action)
return
}
respData, err := json.Marshal(resp)
if err != nil {
log.Debugf("RDP pipe marshal response: %v", err)
return
}
if _, err := conn.Write(respData); err != nil {
log.Debugf("RDP pipe write response: %v", err)
}
}
func (s *windowsPipeServer) handleQuery(remoteIP string) PipeResponse {
peerIP, err := parseAddr(remoteIP)
if err != nil {
log.Debugf("RDP pipe invalid remote IP: %s", remoteIP)
return PipeResponse{Found: false}
}
session, found := s.pending.QueryByPeerIP(peerIP)
if !found {
return PipeResponse{Found: false}
}
return PipeResponse{
Found: true,
SessionID: session.SessionID,
OSUser: session.OSUsername,
Domain: session.Domain,
}
}
func (s *windowsPipeServer) handleConsume(sessionID string) PipeResponse {
if s.pending.Consume(sessionID) {
return PipeResponse{Found: true, SessionID: sessionID}
}
return PipeResponse{Found: false}
}

View File

@@ -0,0 +1,48 @@
package server
// AuthRequest is the sideband authorization request sent by the connecting peer
// to the target peer's RDP auth server over the WireGuard tunnel.
type AuthRequest struct {
JWTToken string `json:"jwt_token"`
RequestedUser string `json:"requested_user"`
ClientPeerIP string `json:"client_peer_ip"`
Nonce string `json:"nonce"`
}
// AuthResponse is the sideband authorization response sent by the target peer
// back to the connecting peer.
type AuthResponse struct {
Status string `json:"status"` // "authorized" or "denied"
SessionID string `json:"session_id,omitempty"`
ExpiresAt int64 `json:"expires_at,omitempty"` // unix timestamp
OSUser string `json:"os_user,omitempty"`
Reason string `json:"reason,omitempty"`
}
// PipeRequest is the IPC request from the Credential Provider DLL to the NetBird agent
// via the named pipe.
type PipeRequest struct {
Action string `json:"action"` // "query_pending" or "consume"
RemoteIP string `json:"remote_ip"` // connecting peer's WG IP
SessionID string `json:"session_id,omitempty"` // for consume action
}
// PipeResponse is the IPC response from the NetBird agent to the Credential Provider DLL.
type PipeResponse struct {
Found bool `json:"found"`
SessionID string `json:"session_id,omitempty"`
OSUser string `json:"os_user,omitempty"`
Domain string `json:"domain,omitempty"`
}
const (
// StatusAuthorized indicates the RDP session was authorized.
StatusAuthorized = "authorized"
// StatusDenied indicates the RDP session was denied.
StatusDenied = "denied"
// PipeActionQuery queries for a pending session by remote IP.
PipeActionQuery = "query_pending"
// PipeActionConsume marks a pending session as consumed.
PipeActionConsume = "consume"
)

301
client/rdp/server/server.go Normal file
View File

@@ -0,0 +1,301 @@
package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
)
const (
// InternalRDPAuthPort is the port the sideband auth server listens on.
InternalRDPAuthPort = 22338
// DefaultRDPAuthPort is the external port on the WireGuard interface (DNAT target).
DefaultRDPAuthPort = 22338
// maxRequestSize is the maximum size of an auth request in bytes.
maxRequestSize = 64 * 1024
// connectionTimeout is the timeout for a single auth connection.
connectionTimeout = 30 * time.Second
)
// JWTValidator validates JWT tokens and extracts user identity.
type JWTValidator interface {
ValidateAndExtract(token string) (userID string, err error)
}
// Authorizer checks if a user is authorized for RDP access.
type Authorizer interface {
Authorize(jwtUserID, osUsername string) (string, error)
}
// Server is the sideband RDP authorization server that listens on the WireGuard interface.
type Server struct {
listener net.Listener
pending *PendingStore
pipeServer PipeServer
jwtValidator JWTValidator
authorizer Authorizer
sshAuthorizer *sshauth.Authorizer // reuses SSH ACL for RDP access control
networkAddr netip.Prefix // WireGuard network for source IP validation
mu sync.Mutex
ctx context.Context
cancel context.CancelFunc
}
// PipeServer is the interface for the named pipe IPC server (platform-specific).
type PipeServer interface {
Start(ctx context.Context) error
Stop() error
}
// Config holds the configuration for the RDP auth server.
type Config struct {
JWTValidator JWTValidator
Authorizer Authorizer
NetworkAddr netip.Prefix
SessionTTL time.Duration
}
// New creates a new RDP sideband auth server.
func New(cfg *Config) *Server {
ttl := cfg.SessionTTL
if ttl <= 0 {
ttl = DefaultSessionTTL
}
pending := NewPendingStore(ttl)
return &Server{
pending: pending,
pipeServer: newPipeServer(pending),
jwtValidator: cfg.JWTValidator,
authorizer: cfg.Authorizer,
sshAuthorizer: sshauth.NewAuthorizer(),
networkAddr: cfg.NetworkAddr,
}
}
// UpdateRDPAuth updates the RDP authorization config (reuses SSH ACL).
func (s *Server) UpdateRDPAuth(config *sshauth.Config) {
s.sshAuthorizer.Update(config)
log.Debugf("RDP auth: updated authorization config")
}
// Start begins listening for sideband auth requests on the given address.
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.listener != nil {
return errors.New("RDP auth server already running")
}
s.ctx, s.cancel = context.WithCancel(ctx)
listenAddr := net.TCPAddrFromAddrPort(addr)
listener, err := net.ListenTCP("tcp", listenAddr)
if err != nil {
return fmt.Errorf("listen on %s: %w", addr, err)
}
s.listener = listener
s.pending.StartCleanup(s.ctx)
if s.pipeServer != nil {
if err := s.pipeServer.Start(s.ctx); err != nil {
log.Warnf("failed to start RDP named pipe server: %v", err)
}
}
go s.acceptLoop()
log.Infof("RDP sideband auth server started on %s", addr)
return nil
}
// Stop shuts down the server and cleans up resources.
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cancel != nil {
s.cancel()
}
if s.pipeServer != nil {
if err := s.pipeServer.Stop(); err != nil {
log.Warnf("failed to stop RDP named pipe server: %v", err)
}
}
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
if err != nil {
return fmt.Errorf("close listener: %w", err)
}
}
log.Info("RDP sideband auth server stopped")
return nil
}
// GetPendingStore returns the pending session store (for testing/named pipe access).
func (s *Server) GetPendingStore() *PendingStore {
return s.pending
}
func (s *Server) acceptLoop() {
for {
conn, err := s.listener.Accept()
if err != nil {
if s.ctx.Err() != nil {
return
}
log.Debugf("RDP auth accept error: %v", err)
continue
}
go s.handleConnection(conn)
}
}
func (s *Server) handleConnection(conn net.Conn) {
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("RDP auth close connection: %v", err)
}
}()
if err := conn.SetDeadline(time.Now().Add(connectionTimeout)); err != nil {
log.Debugf("RDP auth set deadline: %v", err)
return
}
// Validate source IP is from WireGuard network
remoteAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
if err != nil {
log.Debugf("RDP auth parse remote addr: %v", err)
return
}
if !s.networkAddr.Contains(remoteAddr.Addr()) {
log.Warnf("RDP auth rejected connection from non-WG address: %s", remoteAddr.Addr())
return
}
// Read request
data, err := io.ReadAll(io.LimitReader(conn, maxRequestSize))
if err != nil {
log.Debugf("RDP auth read request: %v", err)
return
}
var req AuthRequest
if err := json.Unmarshal(data, &req); err != nil {
log.Debugf("RDP auth unmarshal request: %v", err)
s.sendResponse(conn, &AuthResponse{Status: StatusDenied, Reason: "invalid request format"})
return
}
response := s.processAuthRequest(remoteAddr.Addr(), &req)
s.sendResponse(conn, response)
}
func (s *Server) processAuthRequest(peerIP netip.Addr, req *AuthRequest) *AuthResponse {
// Validate JWT
if s.jwtValidator == nil {
// No JWT validation configured - for POC, accept all requests from WG peers
log.Warnf("RDP auth: no JWT validator configured, accepting request from %s", peerIP)
return s.createSession(peerIP, req, "no-jwt-validation")
}
userID, err := s.jwtValidator.ValidateAndExtract(req.JWTToken)
if err != nil {
log.Warnf("RDP auth JWT validation failed for %s: %v", peerIP, err)
return &AuthResponse{Status: StatusDenied, Reason: "JWT validation failed"}
}
// Check authorization - try explicit authorizer first, then SSH ACL
if s.authorizer != nil {
if _, err := s.authorizer.Authorize(userID, req.RequestedUser); err != nil {
log.Warnf("RDP auth denied for user %s -> %s: %v", userID, req.RequestedUser, err)
return &AuthResponse{Status: StatusDenied, Reason: "not authorized for this user"}
}
} else if s.sshAuthorizer != nil {
if _, err := s.sshAuthorizer.Authorize(userID, req.RequestedUser); err != nil {
log.Warnf("RDP auth denied (SSH ACL) for user %s -> %s: %v", userID, req.RequestedUser, err)
return &AuthResponse{Status: StatusDenied, Reason: "not authorized for this user"}
}
}
return s.createSession(peerIP, req, userID)
}
func (s *Server) createSession(peerIP netip.Addr, req *AuthRequest, jwtUserID string) *AuthResponse {
// Parse domain from requested user (DOMAIN\user or user@domain)
osUser, domain := parseWindowsUsername(req.RequestedUser)
session, err := s.pending.Add(peerIP, osUser, domain, jwtUserID, req.Nonce)
if err != nil {
log.Warnf("RDP auth create session failed: %v", err)
return &AuthResponse{Status: StatusDenied, Reason: err.Error()}
}
return &AuthResponse{
Status: StatusAuthorized,
SessionID: session.SessionID,
ExpiresAt: session.ExpiresAt.Unix(),
OSUser: session.OSUsername,
}
}
func (s *Server) sendResponse(conn net.Conn, resp *AuthResponse) {
data, err := json.Marshal(resp)
if err != nil {
log.Debugf("RDP auth marshal response: %v", err)
return
}
if _, err := conn.Write(data); err != nil {
log.Debugf("RDP auth write response: %v", err)
}
}
// parseWindowsUsername extracts username and domain from Windows username formats.
// Supports DOMAIN\username, username@domain, and plain username.
func parseWindowsUsername(fullUsername string) (username, domain string) {
for i := len(fullUsername) - 1; i >= 0; i-- {
if fullUsername[i] == '\\' {
return fullUsername[i+1:], fullUsername[:i]
}
}
if idx := indexOf(fullUsername, '@'); idx != -1 {
return fullUsername[:idx], fullUsername[idx+1:]
}
return fullUsername, "."
}
func indexOf(s string, c byte) int {
for i := 0; i < len(s); i++ {
if s[i] == c {
return i
}
}
return -1
}

View File

@@ -366,6 +366,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
config.ServerSSHAllowed = msg.ServerSSHAllowed
config.ServerRDPAllowed = msg.ServerRDPAllowed
config.NetworkMonitor = msg.NetworkMonitor
config.DisableClientRoutes = msg.DisableClientRoutes
config.DisableServerRoutes = msg.DisableServerRoutes
@@ -1359,6 +1360,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
}
if engine.IsBlockInbound() {
return gstatus.Errorf(codes.FailedPrecondition, "expose requires inbound connections but 'block inbound' is enabled, disable it first")
}
mgr := engine.GetExposeManager()
if mgr == nil {
return gstatus.Errorf(codes.Internal, "expose manager not available")
@@ -1510,6 +1515,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
ServerRDPAllowed: cfg.ServerRDPAllowed != nil && *cfg.ServerRDPAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,

View File

@@ -58,6 +58,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
serverRDPAllowed := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
@@ -82,6 +83,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
ServerRDPAllowed: &serverRDPAllowed,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
@@ -125,6 +127,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.NotNil(t, cfg.ServerRDPAllowed)
require.Equal(t, serverRDPAllowed, *cfg.ServerRDPAllowed)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
@@ -176,6 +180,7 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"ServerRDPAllowed": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
@@ -236,6 +241,7 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"allow-server-rdp": "ServerRDPAllowed",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",

View File

@@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/config"
)
// registerStates registers all states that need crash recovery cleanup.
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/config"
)
// registerStates registers all states that need crash recovery cleanup.
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})

View File

@@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
hasCommand := session.RawCommand() != ""
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
if err != nil {
@@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) {
}
if hasCommand {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
if err := serverSession.Run(session.RawCommand()); err != nil {
log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
}

View File

@@ -1,6 +1,7 @@
package proxy
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
@@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) {
cancel()
}
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
// when forwarding commands to the backend. This is critical for tools like
// Ansible that send commands such as:
//
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
//
// The single quotes must be preserved so the backend shell receives the
// subshell expression as a single argument to -c.
func TestSSHProxy_CommandQuoting(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
sshClient, cleanup := setupProxySSHClient(t)
defer cleanup()
// These commands simulate what the SSH protocol delivers as exec payloads.
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
// the local shell strips the outer single quotes, and the SSH exec request
// contains the raw string: /bin/sh -c "( echo hello )"
//
// The proxy must forward this string verbatim. Using session.Command()
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
// the command on the backend.
tests := []struct {
name string
command string
expect string
}{
{
name: "subshell_in_double_quotes",
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
expect: "from-subshell\nouter\n",
},
{
name: "printf_with_special_chars",
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
expect: "hello world\n",
},
{
name: "nested_command_substitution",
command: `/bin/sh -c "echo $(echo nested)"`,
expect: "nested\n",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
session, err := sshClient.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()
var stderrBuf bytes.Buffer
session.Stderr = &stderrBuf
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output(tc.command)
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
if stderrBuf.Len() > 0 {
t.Logf("stderr: %s", stderrBuf.String())
}
require.NoError(t, err, "command should succeed: %s", tc.command)
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
case <-time.After(5 * time.Second):
t.Fatalf("command timed out: %s", tc.command)
}
})
}
}
// setupProxySSHClient creates a full proxy test environment and returns
// an SSH client connected through the proxy to a backend NetBird SSH server.
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
t.Helper()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0},
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
mockDaemon := startMockDaemon(t)
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
origStdin := os.Stdin
origStdout := os.Stdout
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
clientConn, proxyConn := net.Pipe()
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
_ = proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err)
client := cryptossh.NewClient(sshClientConn, chans, reqs)
cleanupFn := func() {
_ = client.Close()
_ = clientConn.Close()
cancel()
os.Stdin = origStdin
os.Stdout = origStdout
_ = sshServer.Stop()
mockDaemon.stop()
jwksServer.Close()
}
return client, cleanupFn
}
type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte

View File

@@ -284,19 +284,21 @@ func (s *Server) closeListener(ln net.Listener) {
// Stop closes the SSH server
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.sshServer == nil {
sshServer := s.sshServer
if sshServer == nil {
s.mu.Unlock()
return nil
}
s.sshServer = nil
s.listener = nil
s.mu.Unlock()
if err := s.sshServer.Close(); err != nil {
// Close outside the lock: session handlers need s.mu for unregisterSession.
if err := sshServer.Close(); err != nil {
log.Debugf("close SSH server: %v", err)
}
s.sshServer = nil
s.listener = nil
s.mu.Lock()
maps.Clear(s.sessions)
maps.Clear(s.pendingAuthJWT)
maps.Clear(s.connections)
@@ -307,6 +309,7 @@ func (s *Server) Stop() error {
}
}
maps.Clear(s.remoteForwardListeners)
s.mu.Unlock()
return nil
}

View File

@@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
}
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
hasCommand := session.RawCommand() != ""
if isPty && !hasCommand {
// ssh <host> - PTY interactive session (login)

View File

@@ -153,6 +153,9 @@ func networkAddresses() ([]NetworkAddress, error) {
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}

View File

@@ -43,18 +43,24 @@ func GetInfo(ctx context.Context) *Info {
systemHostname, _ := os.Hostname()
addrs, err := networkAddresses()
if err != nil {
log.Warnf("failed to discover network addresses: %s", err)
}
return &Info{
GoOS: runtime.GOOS,
Kernel: osInfo[0],
Platform: runtime.GOARCH,
OS: osName,
OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname),
CPUs: runtime.NumCPU(),
NetbirdVersion: version.NetbirdVersion(),
UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1],
Environment: env,
GoOS: runtime.GOOS,
Kernel: osInfo[0],
Platform: runtime.GOARCH,
OS: osName,
OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname),
CPUs: runtime.NumCPU(),
NetbirdVersion: version.NetbirdVersion(),
UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1],
NetworkAddresses: addrs,
Environment: env,
}
}

View File

@@ -24,9 +24,10 @@ import (
// Initial state for the debug collection
type debugInitialState struct {
wasDown bool
logLevel proto.LogLevel
isLevelTrace bool
wasDown bool
needsRestoreUp bool
logLevel proto.LogLevel
isLevelTrace bool
}
// Debug collection parameters
@@ -371,46 +372,51 @@ func (s *serviceClient) configureServiceForDebug(
conn proto.DaemonServiceClient,
state *debugInitialState,
enablePersistence bool,
) error {
) {
if state.wasDown {
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("bring service up: %v", err)
log.Warnf("failed to bring service up: %v", err)
} else {
log.Info("Service brought up for debug")
time.Sleep(time.Second * 10)
}
log.Info("Service brought up for debug")
time.Sleep(time.Second * 10)
}
if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil {
return fmt.Errorf("set log level to TRACE: %v", err)
log.Warnf("failed to set log level to TRACE: %v", err)
} else {
log.Info("Log level set to TRACE for debug")
}
log.Info("Log level set to TRACE for debug")
}
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
return fmt.Errorf("bring service down: %v", err)
log.Warnf("failed to bring service down: %v", err)
} else {
state.needsRestoreUp = !state.wasDown
time.Sleep(time.Second)
}
time.Sleep(time.Second)
if enablePersistence {
if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("enable sync response persistence: %v", err)
log.Warnf("failed to enable sync response persistence: %v", err)
} else {
log.Info("Sync response persistence enabled for debug")
}
log.Info("Sync response persistence enabled for debug")
}
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("bring service back up: %v", err)
log.Warnf("failed to bring service back up: %v", err)
} else {
state.needsRestoreUp = false
time.Sleep(time.Second * 3)
}
time.Sleep(time.Second * 3)
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
log.Warnf("failed to start CPU profiling: %v", err)
}
return nil
}
func (s *serviceClient) collectDebugData(
@@ -424,9 +430,7 @@ func (s *serviceClient) collectDebugData(
var wg sync.WaitGroup
startProgressTracker(ctx, &wg, params.duration, progress)
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
return err
}
s.configureServiceForDebug(conn, state, params.enablePersistence)
wg.Wait()
progress.progressBar.Hide()
@@ -482,9 +486,17 @@ func (s *serviceClient) createDebugBundleFromCollection(
// Restore service to original state
func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) {
if state.needsRestoreUp {
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
log.Warnf("failed to restore up state: %v", err)
} else {
log.Info("Service state restored to up")
}
}
if state.wasDown {
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
log.Errorf("Failed to restore down state: %v", err)
log.Warnf("failed to restore down state: %v", err)
} else {
log.Info("Service state restored to down")
}
@@ -492,7 +504,7 @@ func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, stat
if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil {
log.Errorf("Failed to restore log level: %v", err)
log.Warnf("failed to restore log level: %v", err)
} else {
log.Info("Log level restored to original setting")
}

View File

@@ -179,9 +179,11 @@ type StoreConfig struct {
// ReverseProxyConfig contains reverse proxy settings
type ReverseProxyConfig struct {
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
TrustedPeers []string `yaml:"trustedPeers"`
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
TrustedPeers []string `yaml:"trustedPeers"`
AccessLogRetentionDays int `yaml:"accessLogRetentionDays"`
AccessLogCleanupIntervalHours int `yaml:"accessLogCleanupIntervalHours"`
}
// DefaultConfig returns a CombinedConfig with default values
@@ -645,7 +647,9 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
// Build reverse proxy config
reverseProxy := nbconfig.ReverseProxy{
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
AccessLogRetentionDays: mgmt.ReverseProxy.AccessLogRetentionDays,
AccessLogCleanupIntervalHours: mgmt.ReverseProxy.AccessLogCleanupIntervalHours,
}
for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies {
if prefix, err := netip.ParsePrefix(p); err == nil {

Some files were not shown because too many files have changed in this diff Show More