mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
Compare commits
49 Commits
feature/fl
...
claude/rdp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b621a2628b | ||
|
|
4949ca6194 | ||
|
|
c5186f1483 | ||
|
|
5259e5df51 | ||
|
|
ebd78e0122 | ||
|
|
cf86b9a528 | ||
|
|
ee588e1536 | ||
|
|
2a8aacc5c9 | ||
|
|
15709bc666 | ||
|
|
789b4113fe | ||
|
|
d2cdc0efec | ||
|
|
ee343d5d77 | ||
|
|
099c493b18 | ||
|
|
c1d1229ae0 | ||
|
|
94a36cb53e | ||
|
|
c7ba931466 | ||
|
|
413d95b740 | ||
|
|
332c624c55 | ||
|
|
dc160aff36 | ||
|
|
96806bf55f | ||
|
|
d33cd4c95b | ||
|
|
e2c2f64be7 | ||
|
|
cb73b94ffb | ||
|
|
1d920d700c | ||
|
|
bb85eee40a | ||
|
|
aba5d6f0d2 | ||
|
|
0588d2dbe1 | ||
|
|
14b3b77bda | ||
|
|
6da34e483c | ||
|
|
0efef671d7 | ||
|
|
435203b13b | ||
|
|
decb5dd3af | ||
|
|
28fbf96b2a | ||
|
|
9d1a37c644 | ||
|
|
5bf2372c4d | ||
|
|
c2c6396a04 | ||
|
|
aaf813fc0c | ||
|
|
d97fe84296 | ||
|
|
81f45dab21 | ||
|
|
d670e7382a | ||
|
|
cd8c686339 | ||
|
|
f5c41e3018 | ||
|
|
2477f99d89 | ||
|
|
940f530ac2 | ||
|
|
4d3e2f8ad3 | ||
|
|
5ae986e1c4 | ||
|
|
e5914e4e8b | ||
|
|
c238f5425f | ||
|
|
3c3097ea74 |
@@ -31,7 +31,7 @@ jobs:
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
@@ -154,6 +154,26 @@ builds:
|
||||
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-idp-migrate
|
||||
dir: tools/idp-migrate
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
binary: netbird-idp-migrate
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
@@ -166,6 +186,10 @@ archives:
|
||||
- netbird-wasm
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
||||
format: binary
|
||||
- id: netbird-idp-migrate
|
||||
builds:
|
||||
- netbird-idp-migrate
|
||||
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
|
||||
nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
|
||||
@@ -199,9 +199,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level set to trace.")
|
||||
}
|
||||
|
||||
needsRestoreUp := false
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
needsRestoreUp = !stateWasDown
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
|
||||
@@ -217,6 +219,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
needsRestoreUp = false
|
||||
cmd.Println("netbird up")
|
||||
}
|
||||
|
||||
@@ -264,6 +267,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if needsRestoreUp {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird up (restored)")
|
||||
}
|
||||
}
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -201,7 +202,7 @@ func exposeFn(cmd *cobra.Command, args []string) error {
|
||||
|
||||
stream, err := client.ExposeService(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expose service: %w", err)
|
||||
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||
@@ -236,7 +237,7 @@ func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive expose event: %w", err)
|
||||
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
|
||||
|
||||
276
client/cmd/rdp.go
Normal file
276
client/cmd/rdp.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
rdpclient "github.com/netbirdio/netbird/client/rdp/client"
|
||||
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
serverRDPAllowedFlag = "allow-server-rdp"
|
||||
)
|
||||
|
||||
var (
|
||||
rdpUsername string
|
||||
rdpHost string
|
||||
rdpNoBrowser bool
|
||||
rdpNoCache bool
|
||||
serverRDPAllowed bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
rdpCmd.PersistentFlags().StringVarP(&rdpUsername, "user", "u", "", "Windows username on remote peer")
|
||||
rdpCmd.PersistentFlags().BoolVar(&rdpNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
rdpCmd.PersistentFlags().BoolVar(&rdpNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&serverRDPAllowed, serverRDPAllowedFlag, false, "Allow RDP passthrough on peer (passwordless RDP via credential provider)")
|
||||
}
|
||||
|
||||
var rdpCmd = &cobra.Command{
|
||||
Use: "rdp [flags] [user@]host",
|
||||
Short: "Connect to a NetBird peer via RDP (passwordless)",
|
||||
Long: `Connect to a NetBird peer using Remote Desktop Protocol with token-based
|
||||
passwordless authentication. The target peer must have RDP passthrough enabled.
|
||||
|
||||
This command:
|
||||
1. Obtains a JWT token via OIDC authentication
|
||||
2. Sends the token to the target peer's sideband auth service
|
||||
3. If authorized, launches mstsc.exe to connect
|
||||
|
||||
Examples:
|
||||
netbird rdp peer-hostname
|
||||
netbird rdp administrator@peer-hostname
|
||||
netbird rdp --user admin peer-hostname`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: rdpFn,
|
||||
}
|
||||
|
||||
func rdpFn(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
logOutput := "console"
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
// Parse user@host
|
||||
if err := parseRDPHostArg(args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
rdpCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := runRDP(rdpCtx, cmd); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
<-rdpCtx.Done()
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-rdpCtx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRDPHostArg(arg string) error {
|
||||
if strings.Contains(arg, "@") {
|
||||
parts := strings.SplitN(arg, "@", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return errors.New("invalid user@host format")
|
||||
}
|
||||
if rdpUsername == "" {
|
||||
rdpUsername = parts[0]
|
||||
}
|
||||
rdpHost = parts[1]
|
||||
} else {
|
||||
rdpHost = arg
|
||||
}
|
||||
|
||||
if rdpUsername == "" {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
rdpUsername = sudoUser
|
||||
} else if currentUser, err := user.Current(); err == nil {
|
||||
rdpUsername = currentUser.Username
|
||||
} else {
|
||||
rdpUsername = "Administrator"
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runRDP(ctx context.Context, cmd *cobra.Command) error {
|
||||
// Connect to daemon
|
||||
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer func() { _ = grpcConn.Close() }()
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(grpcConn)
|
||||
|
||||
// Resolve peer IP
|
||||
peerIP, err := resolvePeerIP(ctx, daemonClient, rdpHost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve peer %s: %w", rdpHost, err)
|
||||
}
|
||||
|
||||
cmd.Printf("Connecting to %s@%s (%s)...\n", rdpUsername, rdpHost, peerIP)
|
||||
|
||||
// Obtain JWT token
|
||||
hint := profilemanager.GetLoginHint()
|
||||
var browserOpener func(string) error
|
||||
if !rdpNoBrowser {
|
||||
browserOpener = util.OpenBrowser
|
||||
}
|
||||
|
||||
jwtToken, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !rdpNoCache, hint, browserOpener)
|
||||
if err != nil {
|
||||
return fmt.Errorf("JWT authentication: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("JWT authentication successful")
|
||||
cmd.Println("Authenticated. Requesting RDP access...")
|
||||
|
||||
// Generate nonce for replay protection
|
||||
nonce, err := rdpserver.GenerateNonce()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Send sideband auth request
|
||||
authClient := rdpclient.New()
|
||||
authAddr := net.JoinHostPort(peerIP, fmt.Sprintf("%d", rdpserver.DefaultRDPAuthPort))
|
||||
|
||||
resp, err := authClient.RequestAuth(ctx, authAddr, &rdpserver.AuthRequest{
|
||||
JWTToken: jwtToken,
|
||||
RequestedUser: rdpUsername,
|
||||
ClientPeerIP: "", // will be filled by the server from the connection
|
||||
Nonce: nonce,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.Printf("Failed to authorize RDP session with %s\n", rdpHost)
|
||||
cmd.Printf("\nTroubleshooting:\n")
|
||||
cmd.Printf(" 1. Check connectivity: netbird status -d\n")
|
||||
cmd.Printf(" 2. Verify RDP passthrough is enabled on the target peer\n")
|
||||
return fmt.Errorf("sideband auth: %w", err)
|
||||
}
|
||||
|
||||
if resp.Status != rdpserver.StatusAuthorized {
|
||||
return fmt.Errorf("RDP access denied: %s", resp.Reason)
|
||||
}
|
||||
|
||||
cmd.Printf("RDP access authorized (session: %s, user: %s)\n", resp.SessionID, resp.OSUser)
|
||||
cmd.Printf("Launching Remote Desktop client...\n")
|
||||
|
||||
// Launch mstsc.exe (platform-specific)
|
||||
if err := launchRDPClient(peerIP); err != nil {
|
||||
return fmt.Errorf("launch RDP client: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolvePeerIP resolves a peer hostname/FQDN to its WireGuard IP address
|
||||
// by querying the daemon for the current peer status.
|
||||
func resolvePeerIP(ctx context.Context, client proto.DaemonServiceClient, peerAddress string) (string, error) {
|
||||
statusResp, err := client.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get daemon status: %w", err)
|
||||
}
|
||||
|
||||
if statusResp.GetFullStatus() == nil {
|
||||
return "", errors.New("daemon returned empty status")
|
||||
}
|
||||
|
||||
for _, peer := range statusResp.GetFullStatus().GetPeers() {
|
||||
if matchesPeer(peer, peerAddress) {
|
||||
ip := peer.GetIP()
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
// Strip CIDR suffix if present
|
||||
if idx := strings.Index(ip, "/"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If not found as a peer name, try as a direct IP
|
||||
if addr, err := net.ResolveIPAddr("ip", peerAddress); err == nil {
|
||||
return addr.String(), nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("peer %q not found in network", peerAddress)
|
||||
}
|
||||
|
||||
func matchesPeer(peer *proto.PeerState, address string) bool {
|
||||
address = strings.ToLower(address)
|
||||
|
||||
if strings.EqualFold(peer.GetFqdn(), address) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Match against FQDN without trailing dot
|
||||
fqdn := strings.TrimSuffix(peer.GetFqdn(), ".")
|
||||
if strings.EqualFold(fqdn, address) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Match against short hostname (first part of FQDN)
|
||||
if parts := strings.SplitN(fqdn, ".", 2); len(parts) > 0 {
|
||||
if strings.EqualFold(parts[0], address) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Match against IP
|
||||
ip := peer.GetIP()
|
||||
if idx := strings.Index(ip, "/"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
if ip == address {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
13
client/cmd/rdp_stub.go
Normal file
13
client/cmd/rdp_stub.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !windows
|
||||
|
||||
package cmd
|
||||
|
||||
import "fmt"
|
||||
|
||||
// launchRDPClient is a stub for non-Windows platforms.
|
||||
func launchRDPClient(peerIP string) error {
|
||||
fmt.Printf("RDP session authorized for %s\n", peerIP)
|
||||
fmt.Println("Note: mstsc.exe is only available on Windows.")
|
||||
fmt.Printf("Use any RDP client to connect to %s:3389\n", peerIP)
|
||||
return nil
|
||||
}
|
||||
34
client/cmd/rdp_windows.go
Normal file
34
client/cmd/rdp_windows.go
Normal file
@@ -0,0 +1,34 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// launchRDPClient launches the native Windows Remote Desktop client (mstsc.exe).
|
||||
func launchRDPClient(peerIP string) error {
|
||||
mstscPath, err := exec.LookPath("mstsc.exe")
|
||||
if err != nil {
|
||||
return fmt.Errorf("mstsc.exe not found: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command(mstscPath, fmt.Sprintf("/v:%s", peerIP))
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start mstsc.exe: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("launched mstsc.exe (PID %d) connecting to %s", cmd.Process.Pid, peerIP)
|
||||
|
||||
// Don't wait for mstsc to exit - it runs independently
|
||||
go func() {
|
||||
if err := cmd.Wait(); err != nil {
|
||||
log.Debugf("mstsc.exe exited: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -150,6 +150,7 @@ func init() {
|
||||
rootCmd.AddCommand(logoutCmd)
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
rootCmd.AddCommand(rdpCmd)
|
||||
rootCmd.AddCommand(networksCMD)
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
|
||||
@@ -25,10 +25,10 @@ func TestServiceParamsPath(t *testing.T) {
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
|
||||
configs.StateDir = "/var/lib/netbird"
|
||||
assert.Equal(t, "/var/lib/netbird/service.json", serviceParamsPath())
|
||||
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
|
||||
|
||||
configs.StateDir = "/custom/state"
|
||||
assert.Equal(t, "/custom/state/service.json", serviceParamsPath())
|
||||
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
|
||||
}
|
||||
|
||||
func TestSaveAndLoadServiceParams(t *testing.T) {
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +15,22 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMain intercepts when this test binary is run as a daemon subprocess.
|
||||
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
|
||||
// "service run ..." arguments. Since the test binary can't handle cobra CLI
|
||||
// args, it exits immediately, causing daemon -r to respawn rapidly until
|
||||
// hitting the rate limit and exiting. This makes service restart unreliable.
|
||||
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
|
||||
func TestMain(m *testing.M) {
|
||||
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
|
||||
<-sig
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
@@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) {
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
|
||||
@@ -356,6 +356,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverRDPAllowedFlag).Changed {
|
||||
req.ServerRDPAllowed = &serverRDPAllowed
|
||||
}
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
@@ -458,6 +461,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverRDPAllowedFlag).Changed {
|
||||
ic.ServerRDPAllowed = &serverRDPAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
@@ -582,6 +588,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverRDPAllowedFlag).Changed {
|
||||
loginRequest.ServerRDPAllowed = &serverRDPAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
@@ -35,20 +36,27 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
|
||||
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
|
||||
log.Info("forcing userspace firewall")
|
||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||
|
||||
// Kernel cannot fall back to anything else, need to return error
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
}
|
||||
|
||||
// Fall back to the userspace packet filter if native is unavailable
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
||||
@@ -160,3 +168,17 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func forceUserspaceFirewall() bool {
|
||||
val := os.Getenv(EnvForceUserspaceFirewall)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
force, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
|
||||
return false
|
||||
}
|
||||
return force
|
||||
}
|
||||
|
||||
@@ -7,6 +7,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when
|
||||
// native iptables/nftables is available. This only applies when the WireGuard interface
|
||||
// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of
|
||||
// kernel netfilter rules.
|
||||
const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL"
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
|
||||
@@ -33,7 +33,6 @@ type Manager struct {
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
@@ -64,10 +63,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
state := &ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
MTU: m.router.mtu,
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
MTU: m.router.mtu,
|
||||
},
|
||||
}
|
||||
stateManager.RegisterState(state)
|
||||
@@ -203,12 +201,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
// AllowNetbird allows netbird interface traffic.
|
||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
@@ -286,6 +282,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
|
||||
@@ -47,8 +47,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestIptablesManager(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -36,6 +36,7 @@ const (
|
||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
chainRTRDR = "NETBIRD-RT-RDR"
|
||||
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
||||
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
@@ -43,6 +44,7 @@ const (
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
jumpNatOutput = "jump-nat-output"
|
||||
jumpMSSClamp = "jump-mss-clamp"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
@@ -387,6 +389,14 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
|
||||
// Remove jump rules from built-in chains before deleting custom chains,
|
||||
// otherwise the chain deletion fails with "device or resource busy".
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
|
||||
log.Debugf("clean OUTPUT jump rule: %v", err)
|
||||
}
|
||||
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
@@ -396,6 +406,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainNATOutput, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
@@ -970,6 +981,81 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
||||
func (r *router) ensureNATOutputChain() error {
|
||||
if _, exists := r.rules[jumpNatOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
if !chainExists {
|
||||
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
||||
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
|
||||
if !chainExists {
|
||||
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
||||
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNatOutput] = jumpRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
|
||||
@@ -9,10 +9,9 @@ import (
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
@@ -23,10 +22,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||
return i.UserspaceBind
|
||||
}
|
||||
|
||||
type ShutdownState struct {
|
||||
sync.Mutex
|
||||
|
||||
|
||||
@@ -169,6 +169,14 @@ type Manager interface {
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
|
||||
@@ -40,7 +40,6 @@ func getTableName() string {
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
@@ -106,10 +105,9 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
// cleanup using Close() without needing to store specific rules.
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
MTU: m.router.mtu,
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
MTU: m.router.mtu,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -205,12 +203,10 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return m.router.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
// AllowNetbird allows netbird interface traffic.
|
||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -346,6 +342,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRawOutput = "netbird-raw-out"
|
||||
chainNameRawPrerouting = "netbird-raw-pre"
|
||||
|
||||
@@ -52,8 +52,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
|
||||
// just check on the local interface
|
||||
|
||||
@@ -36,6 +36,7 @@ const (
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameNATOutput = "netbird-nat-output"
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
@@ -1853,6 +1854,130 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
||||
func (r *router) ensureNATOutputChain() error {
|
||||
if _, exists := r.chains[chainNameNATOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameNATOutput,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
delete(r.chains, chainNameNATOutput)
|
||||
return fmt.Errorf("create NAT output chain: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameNATOutput],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||
func (r *router) applyNetwork(
|
||||
network firewall.Network,
|
||||
|
||||
@@ -8,10 +8,9 @@ import (
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
@@ -22,10 +21,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||
return i.UserspaceBind
|
||||
}
|
||||
|
||||
type ShutdownState struct {
|
||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||
}
|
||||
|
||||
@@ -140,6 +140,17 @@ type Manager struct {
|
||||
mtu uint16
|
||||
mssClampValue uint16
|
||||
mssClampEnabled bool
|
||||
|
||||
// Only one hook per protocol is supported. Outbound direction only.
|
||||
udpHookOut atomic.Pointer[packetHook]
|
||||
tcpHookOut atomic.Pointer[packetHook]
|
||||
}
|
||||
|
||||
// packetHook stores a registered hook for a specific IP:port.
|
||||
type packetHook struct {
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
fn func([]byte) bool
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -594,6 +605,8 @@ func (m *Manager) resetState() {
|
||||
maps.Clear(m.incomingRules)
|
||||
maps.Clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
m.udpHookOut.Store(nil)
|
||||
m.tcpHookOut.Store(nil)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
@@ -713,6 +726,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
}
|
||||
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
||||
// SNATed routed traffic may appear as local IP but still requires clamping.
|
||||
if m.mssClampEnabled {
|
||||
@@ -895,38 +911,21 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
d.dnatOrigPort = 0
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
// Check specific destination IP first
|
||||
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check IPv4 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
if h.ip == dstIP && h.port == dport {
|
||||
return h.fn(packetData)
|
||||
}
|
||||
|
||||
// Check IPv6 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1278,12 +1277,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
// if rule has UDP hook (and if we are here we match this rule)
|
||||
// we ignore rule.drop and call this hook
|
||||
if rule.udpHook != nil {
|
||||
return rule.mgmtId, rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
@@ -1342,65 +1335,30 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
return sourceMatched
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
udpHook: hook,
|
||||
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
if hook == nil {
|
||||
m.udpHookOut.Store(nil)
|
||||
return
|
||||
}
|
||||
|
||||
if ip.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
// Incoming UDP hooks are stored in allow rules map
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip][r.id] = r
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
m.udpHookOut.Store(&packetHook{
|
||||
ip: ip,
|
||||
port: dPort,
|
||||
fn: hook,
|
||||
})
|
||||
}
|
||||
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// Check incoming hooks (stored in allow rules)
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
if hook == nil {
|
||||
m.tcpHookOut.Store(nil)
|
||||
return
|
||||
}
|
||||
// Check outgoing hooks
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
m.tcpHookOut.Store(&packetHook{
|
||||
ip: ip,
|
||||
port: dPort,
|
||||
fn: hook,
|
||||
})
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
@@ -186,81 +187,52 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddUDPPacketHook(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in bool
|
||||
expDir fw.RuleDirection
|
||||
ip netip.Addr
|
||||
dPort uint16
|
||||
hook func([]byte) bool
|
||||
expectedID string
|
||||
}{
|
||||
{
|
||||
name: "Test Outgoing UDP Packet Hook",
|
||||
in: false,
|
||||
expDir: fw.RuleDirectionOUT,
|
||||
ip: netip.MustParseAddr("10.168.0.1"),
|
||||
dPort: 8000,
|
||||
hook: func([]byte) bool { return true },
|
||||
},
|
||||
{
|
||||
name: "Test Incoming UDP Packet Hook",
|
||||
in: true,
|
||||
expDir: fw.RuleDirectionIN,
|
||||
ip: netip.MustParseAddr("::1"),
|
||||
dPort: 9000,
|
||||
hook: func([]byte) bool { return false },
|
||||
},
|
||||
}
|
||||
func TestSetUDPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
var called bool
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
|
||||
called = true
|
||||
return true
|
||||
})
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
h := manager.udpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
||||
assert.Equal(t, uint16(8000), h.port)
|
||||
assert.True(t, h.fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
// Incoming UDP hooks are stored in allow rules map
|
||||
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.incomingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
if len(manager.outgoingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
||||
assert.Nil(t, manager.udpHookOut.Load())
|
||||
}
|
||||
|
||||
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
if tt.dPort != addedRule.dPort.Values[0] {
|
||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
||||
return
|
||||
}
|
||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||
return
|
||||
}
|
||||
if addedRule.udpHook == nil {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
func TestSetTCPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
var called bool
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
|
||||
called = true
|
||||
return true
|
||||
})
|
||||
|
||||
h := manager.tcpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
||||
assert.Equal(t, uint16(53), h.port)
|
||||
assert.True(t, h.fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
||||
assert.Nil(t, manager.tcpHookOut.Load())
|
||||
}
|
||||
|
||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||
@@ -530,39 +502,12 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Add a UDP packet hook
|
||||
hookFunc := func(data []byte) bool { return true }
|
||||
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
|
||||
|
||||
if !found {
|
||||
t.Fatalf("The hook was not added properly.")
|
||||
}
|
||||
|
||||
// Now remove the packet hook
|
||||
err = manager.RemovePacketHook(hookID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to remove hook: %s", err)
|
||||
}
|
||||
|
||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
t.Fatalf("The hook was not removed properly.")
|
||||
}
|
||||
}
|
||||
}
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
|
||||
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
|
||||
}
|
||||
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
@@ -592,8 +537,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
}
|
||||
|
||||
hookCalled := false
|
||||
hookID := manager.AddUDPPacketHook(
|
||||
false,
|
||||
manager.SetUDPPacketHook(
|
||||
netip.MustParseAddr("100.10.0.100"),
|
||||
53,
|
||||
func([]byte) bool {
|
||||
@@ -601,7 +545,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
return true
|
||||
},
|
||||
)
|
||||
require.NotEmpty(t, hookID)
|
||||
|
||||
// Create test UDP packet
|
||||
ipv4 := &layers.IPv4{
|
||||
|
||||
@@ -144,6 +144,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
if err != nil {
|
||||
log.Warnf("failed to get interfaces: %v", err)
|
||||
} else {
|
||||
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
||||
// case where an interface comes up between refreshes.
|
||||
for _, intf := range interfaces {
|
||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||
}
|
||||
|
||||
@@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
@@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return fmt.Errorf("output DNAT not supported without native firewall")
|
||||
}
|
||||
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
|
||||
@@ -18,9 +18,7 @@ type PeerRule struct {
|
||||
protoLayer gopacket.LayerType
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
|
||||
udpHook func([]byte) bool
|
||||
drop bool
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
|
||||
@@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) {
|
||||
{
|
||||
name: "UDPTraffic_WithHook",
|
||||
setup: func(m *Manager) {
|
||||
hookFunc := func([]byte) bool {
|
||||
return true
|
||||
}
|
||||
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
|
||||
return true // drop (intercepted by hook)
|
||||
})
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageOutbound1to1NAT,
|
||||
StageOutboundPortReverse,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
|
||||
@@ -15,14 +15,17 @@ type PacketFilter interface {
|
||||
// FilterInbound filter incoming packets from external sources to host
|
||||
FilterInbound(packetData []byte, size int) bool
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||
// Hook function receives raw network packet data as argument.
|
||||
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
|
||||
// Hook function returns true if the packet should be dropped.
|
||||
// Only one UDP hook is supported; calling again replaces the previous hook.
|
||||
// Pass nil hook to remove.
|
||||
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
|
||||
// Hook function returns true if the packet should be dropped.
|
||||
// Only one TCP hook is supported; calling again replaces the previous hook.
|
||||
// Pass nil hook to remove.
|
||||
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
}
|
||||
|
||||
// FilteredDevice to override Read or Write of packets
|
||||
|
||||
@@ -34,18 +34,28 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
// SetUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
@@ -75,17 +85,3 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemovePacketHook mocks base method.
|
||||
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemovePacketHook indicates an expected call of RemovePacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||
}
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockPacketFilter is a mock of PacketFilter interface.
|
||||
type MockPacketFilter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPacketFilterMockRecorder
|
||||
}
|
||||
|
||||
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
|
||||
type MockPacketFilterMockRecorder struct {
|
||||
mock *MockPacketFilter
|
||||
}
|
||||
|
||||
// NewMockPacketFilter creates a new mock instance.
|
||||
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
|
||||
mock := &MockPacketFilter{ctrl: ctrl}
|
||||
mock.recorder = &MockPacketFilterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterInbound indicates an expected call of FilterInbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
||||
}
|
||||
|
||||
// FilterOutbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
||||
}
|
||||
|
||||
// SetNetwork mocks base method.
|
||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
||||
}
|
||||
|
||||
// SetNetwork indicates an expected call of SetNetwork.
|
||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
||||
}
|
||||
@@ -19,6 +19,9 @@ import (
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
func TestDefaultManager(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
@@ -135,6 +138,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
func TestDefaultManagerStateless(t *testing.T) {
|
||||
// stateless currently only in userspace, so we have to disable kernel
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
@@ -194,6 +198,7 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
// This tests the full ACL manager -> uspfilter integration.
|
||||
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
@@ -258,6 +263,7 @@ func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||
// up when they're removed from the network map in a subsequent update.
|
||||
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
@@ -339,6 +345,7 @@ func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||
// one added without leaking.
|
||||
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||
var needsLogin bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
var isAuthError bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
@@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
|
||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow()
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
@@ -221,7 +215,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
||||
config := &PKCEAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
@@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
||||
|
||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow()
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
@@ -266,7 +254,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
||||
config := &DeviceAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
||||
Domain: protoConfig.Domain,
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||
@@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
||||
}
|
||||
|
||||
// doMgmLogin performs the actual login operation with the management service
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(sysInfo)
|
||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
@@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(info)
|
||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -44,6 +44,10 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// androidRunOverride is set on Android to inject mobile dependencies
|
||||
// when using embed.Client (which calls Run() with empty MobileDependency).
|
||||
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
@@ -76,6 +80,9 @@ func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
if androidRunOverride != nil {
|
||||
return androidRunOverride(c, runningChan, logPath)
|
||||
}
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
}
|
||||
|
||||
@@ -104,6 +111,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
stateFilePath string,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
@@ -113,6 +121,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
@@ -534,6 +543,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
ServerRDPAllowed: config.ServerRDPAllowed != nil && *config.ServerRDPAllowed,
|
||||
EnableSSHRoot: config.EnableSSHRoot,
|
||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||
@@ -610,12 +620,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
||||
|
||||
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
@@ -634,12 +638,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
||||
}
|
||||
|
||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||
|
||||
73
client/internal/connect_android_default.go
Normal file
73
client/internal/connect_android_default.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
|
||||
// It returns an empty interface list, which means ICE P2P candidates won't be
|
||||
// discovered — connections will fall back to relay. Applications that need P2P
|
||||
// should provide a real implementation via runOnAndroidEmbed that uses
|
||||
// Android's ConnectivityManager to enumerate network interfaces.
|
||||
type noopIFaceDiscover struct{}
|
||||
|
||||
func (noopIFaceDiscover) IFaces() (string, error) {
|
||||
// Return empty JSON array — no local interfaces advertised for ICE.
|
||||
// This is intentional: without Android's ConnectivityManager, we cannot
|
||||
// reliably enumerate interfaces (netlink is restricted on Android 11+).
|
||||
// Relay connections still work; only P2P hole-punching is disabled.
|
||||
return "[]", nil
|
||||
}
|
||||
|
||||
// noopNetworkChangeListener is a stub for embed.Client on Android.
|
||||
// Network change events are ignored since the embed client manages its own
|
||||
// reconnection logic via the engine's built-in retry mechanism.
|
||||
type noopNetworkChangeListener struct{}
|
||||
|
||||
func (noopNetworkChangeListener) OnNetworkChanged(string) {
|
||||
// No-op: embed.Client relies on the engine's internal reconnection
|
||||
// logic rather than OS-level network change notifications.
|
||||
}
|
||||
|
||||
func (noopNetworkChangeListener) SetInterfaceIP(string) {
|
||||
// No-op: in netstack mode, the overlay IP is managed by the userspace
|
||||
// network stack, not by OS-level interface configuration.
|
||||
}
|
||||
|
||||
// noopDnsReadyListener is a stub for embed.Client on Android.
|
||||
// DNS readiness notifications are not needed in netstack/embed mode
|
||||
// since system DNS is disabled and DNS resolution happens externally.
|
||||
type noopDnsReadyListener struct{}
|
||||
|
||||
func (noopDnsReadyListener) OnReady() {
|
||||
// No-op: embed.Client does not need DNS readiness notifications.
|
||||
// System DNS is disabled in netstack mode.
|
||||
}
|
||||
|
||||
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
|
||||
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
|
||||
var _ dns.ReadyListener = noopDnsReadyListener{}
|
||||
|
||||
func init() {
|
||||
// Wire up the default override so embed.Client.Start() works on Android
|
||||
// with netstack mode. Provides complete no-op stubs for all mobile
|
||||
// dependencies so the engine's existing Android code paths work unchanged.
|
||||
// Applications that need P2P ICE or real DNS should replace this by
|
||||
// setting androidRunOverride before calling Start().
|
||||
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
|
||||
return c.runOnAndroidEmbed(
|
||||
noopIFaceDiscover{},
|
||||
noopNetworkChangeListener{},
|
||||
[]netip.AddrPort{},
|
||||
noopDnsReadyListener{},
|
||||
runningChan,
|
||||
logPath,
|
||||
)
|
||||
}
|
||||
}
|
||||
32
client/internal/connect_android_embed.go
Normal file
32
client/internal/connect_android_embed.go
Normal file
@@ -0,0 +1,32 @@
|
||||
//go:build android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
|
||||
// so embed.Client.Start() can detect when the engine is ready.
|
||||
// It provides complete MobileDependency so the engine's existing
|
||||
// Android code paths work unchanged.
|
||||
func (c *ConnectClient) runOnAndroidEmbed(
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
runningChan chan struct{},
|
||||
logPath string,
|
||||
) error {
|
||||
mobileDependency := MobileDependency{
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, runningChan, logPath)
|
||||
}
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
@@ -52,6 +53,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||
service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
|
||||
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
|
||||
mutex.prof: Mutex profiling information.
|
||||
goroutine.prof: Goroutine profiling information.
|
||||
@@ -359,6 +361,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addServiceParams(); err != nil {
|
||||
log.Errorf("failed to add service params to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addMetrics(); err != nil {
|
||||
log.Errorf("failed to add metrics to debug bundle: %v", err)
|
||||
}
|
||||
@@ -488,6 +494,90 @@ func (g *BundleGenerator) addConfig() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
serviceParamsFile = "service.json"
|
||||
serviceParamsBundle = "service_params.json"
|
||||
maskedValue = "***"
|
||||
envVarPrefix = "NB_"
|
||||
jsonKeyManagementURL = "management_url"
|
||||
jsonKeyServiceEnv = "service_env_vars"
|
||||
)
|
||||
|
||||
var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
|
||||
|
||||
// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
|
||||
// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
|
||||
func (g *BundleGenerator) addServiceParams() error {
|
||||
path := filepath.Join(configs.StateDir, serviceParamsFile)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read service params: %w", err)
|
||||
}
|
||||
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||
return fmt.Errorf("parse service params: %w", err)
|
||||
}
|
||||
|
||||
if g.anonymize {
|
||||
if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
|
||||
params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
|
||||
}
|
||||
}
|
||||
|
||||
g.sanitizeServiceEnvVars(params)
|
||||
|
||||
sanitizedData, err := json.MarshalIndent(params, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal sanitized service params: %w", err)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
|
||||
return fmt.Errorf("add service params to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
|
||||
// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
|
||||
// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
|
||||
func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
|
||||
envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
sanitized := make(map[string]any, len(envVars))
|
||||
for k, v := range envVars {
|
||||
val, _ := v.(string)
|
||||
switch {
|
||||
case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
|
||||
sanitized[k] = maskedValue
|
||||
case g.anonymize:
|
||||
sanitized[k] = g.anonymizer.AnonymizeString(val)
|
||||
default:
|
||||
sanitized[k] = val
|
||||
}
|
||||
}
|
||||
params[jsonKeyServiceEnv] = sanitized
|
||||
}
|
||||
|
||||
// isSensitiveEnvVar returns true for env var names that may contain secrets.
|
||||
func isSensitiveEnvVar(key string) bool {
|
||||
lower := strings.ToLower(key)
|
||||
for _, s := range sensitiveEnvSubstrings {
|
||||
if strings.Contains(lower, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
||||
configContent.WriteString("NetBird Client Configuration:\n\n")
|
||||
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -10,6 +14,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -420,6 +425,226 @@ func TestAnonymizeNetworkMap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSensitiveEnvVar(t *testing.T) {
|
||||
tests := []struct {
|
||||
key string
|
||||
sensitive bool
|
||||
}{
|
||||
{"NB_SETUP_KEY", true},
|
||||
{"NB_API_TOKEN", true},
|
||||
{"NB_CLIENT_SECRET", true},
|
||||
{"NB_PASSWORD", true},
|
||||
{"NB_CREDENTIAL", true},
|
||||
{"NB_LOG_LEVEL", false},
|
||||
{"NB_MANAGEMENT_URL", false},
|
||||
{"NB_HOSTNAME", false},
|
||||
{"HOME", false},
|
||||
{"PATH", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.key, func(t *testing.T) {
|
||||
assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeServiceEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
anonymize bool
|
||||
input map[string]any
|
||||
check func(t *testing.T, params map[string]any)
|
||||
}{
|
||||
{
|
||||
name: "no env vars key",
|
||||
anonymize: false,
|
||||
input: map[string]any{"management_url": "https://mgmt.example.com"},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
|
||||
_, ok := params[jsonKeyServiceEnv]
|
||||
assert.False(t, ok, "service_env_vars should not be added")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-NB vars are masked",
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"HOME": "/root",
|
||||
"PATH": "/usr/bin",
|
||||
"NB_LOG_LEVEL": "debug",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
|
||||
assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
|
||||
assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sensitive NB vars are masked",
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_SETUP_KEY": "abc123",
|
||||
"NB_API_TOKEN": "tok_xyz",
|
||||
"NB_LOG_LEVEL": "info",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
|
||||
assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
|
||||
assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "safe NB vars anonymized when anonymize is true",
|
||||
anonymize: true,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
|
||||
"NB_LOG_LEVEL": "debug",
|
||||
"NB_SETUP_KEY": "secret",
|
||||
"SOME_OTHER": "val",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
// Safe NB_ values should be anonymized (not the original, not masked)
|
||||
mgmtVal := env["NB_MANAGEMENT_URL"].(string)
|
||||
assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
|
||||
assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
|
||||
|
||||
logVal := env["NB_LOG_LEVEL"].(string)
|
||||
assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
|
||||
|
||||
// Sensitive and non-NB_ still masked
|
||||
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
|
||||
assert.Equal(t, maskedValue, env["SOME_OTHER"])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
g := &BundleGenerator{
|
||||
anonymize: tt.anonymize,
|
||||
anonymizer: anonymizer,
|
||||
}
|
||||
g.sanitizeServiceEnvVars(tt.input)
|
||||
tt.check(t, tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddServiceParams(t *testing.T) {
|
||||
t.Run("missing service.json returns nil", func(t *testing.T) {
|
||||
g := &BundleGenerator{
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
}
|
||||
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = t.TempDir()
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
err := g.addServiceParams()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = dir
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
input := map[string]any{
|
||||
jsonKeyManagementURL: "https://api.example.com:443",
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_LOG_LEVEL": "trace",
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
||||
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
g := &BundleGenerator{
|
||||
anonymize: true,
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
archive: zw,
|
||||
}
|
||||
|
||||
require.NoError(t, g.addServiceParams())
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, zr.File, 1)
|
||||
assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
|
||||
|
||||
rc, err := zr.File[0].Open()
|
||||
require.NoError(t, err)
|
||||
defer rc.Close()
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
||||
|
||||
mgmt := result[jsonKeyManagementURL].(string)
|
||||
assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
|
||||
assert.NotEmpty(t, mgmt)
|
||||
|
||||
env := result[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
|
||||
})
|
||||
|
||||
t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = dir
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
input := map[string]any{
|
||||
jsonKeyManagementURL: "https://api.example.com:443",
|
||||
}
|
||||
data, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
||||
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
g := &BundleGenerator{
|
||||
anonymize: false,
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
archive: zw,
|
||||
}
|
||||
|
||||
require.NoError(t, g.addServiceParams())
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||
require.NoError(t, err)
|
||||
|
||||
rc, err := zr.File[0].Open()
|
||||
require.NoError(t, err)
|
||||
defer rc.Close()
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
||||
|
||||
assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to check if IP is in CGNAT range
|
||||
func isInCGNATRange(ip net.IP) bool {
|
||||
cgnat := net.IPNet{
|
||||
|
||||
@@ -73,6 +73,9 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
return nil
|
||||
}
|
||||
w.response = m
|
||||
if m.MsgHdr.Truncated {
|
||||
w.SetMeta("truncated", "true")
|
||||
}
|
||||
return w.ResponseWriter.WriteMsg(m)
|
||||
}
|
||||
|
||||
@@ -195,10 +198,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
startTime := time.Now()
|
||||
requestID := resutil.GenerateRequestID()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": requestID,
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
question := r.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
@@ -261,9 +268,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
||||
meta += " " + k + "=" + v
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
|
||||
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||
meta, time.Since(startTime))
|
||||
cw.response.Len(), meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
|
||||
@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_Stop tests cleanup on Stop
|
||||
// TestLocalResolver_Stop tests cleanup on GracefullyStop
|
||||
func TestLocalResolver_Stop(t *testing.T) {
|
||||
t.Run("Stop clears all state", func(t *testing.T) {
|
||||
t.Run("GracefullyStop clears all state", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
|
||||
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||
})
|
||||
|
||||
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
|
||||
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
|
||||
resolver.Stop()
|
||||
})
|
||||
|
||||
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
|
||||
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
lookupStarted := make(chan struct{})
|
||||
|
||||
@@ -90,6 +90,11 @@ func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// SetFirewall mock implementation of SetFirewall from Server interface
|
||||
func (m *MockServer) SetFirewall(Firewall) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||
func (m *MockServer) BeginBatch() {
|
||||
// Mock implementation - no-op
|
||||
|
||||
@@ -104,3 +104,23 @@ func (r *responseWriter) TsigTimersOnly(bool) {
|
||||
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
||||
func (r *responseWriter) Hijack() {
|
||||
}
|
||||
|
||||
// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging.
|
||||
func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr {
|
||||
var srcIP net.IP
|
||||
if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil {
|
||||
srcIP = ipv4.(*layers.IPv4).SrcIP
|
||||
} else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil {
|
||||
srcIP = ipv6.(*layers.IPv6).SrcIP
|
||||
}
|
||||
|
||||
var srcPort int
|
||||
if udp := packet.Layer(layers.LayerTypeUDP); udp != nil {
|
||||
srcPort = int(udp.(*layers.UDP).SrcPort)
|
||||
}
|
||||
|
||||
if srcIP == nil {
|
||||
return nil
|
||||
}
|
||||
return &net.UDPAddr{IP: srcIP, Port: srcPort}
|
||||
}
|
||||
|
||||
@@ -58,6 +58,7 @@ type Server interface {
|
||||
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
||||
PopulateManagementDomain(mgmtURL *url.URL) error
|
||||
SetRouteChecker(func(netip.Addr) bool)
|
||||
SetFirewall(Firewall)
|
||||
}
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
@@ -151,7 +152,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
|
||||
if config.WgInterface.IsUserspaceBind() {
|
||||
dnsService = NewServiceViaMemory(config.WgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort)
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort, nil)
|
||||
}
|
||||
|
||||
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
|
||||
@@ -186,11 +187,16 @@ func NewDefaultServerIos(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
iosDnsManager IosDnsManager,
|
||||
hostsDnsList []netip.AddrPort,
|
||||
statusRecorder *peer.Status,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
log.Debugf("iOS host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.addHostRootZone()
|
||||
return ds
|
||||
}
|
||||
|
||||
@@ -374,6 +380,17 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
||||
return s.service.RuntimeIP()
|
||||
}
|
||||
|
||||
// SetFirewall sets the firewall used for DNS port DNAT rules.
|
||||
// This must be called before Initialize when using the listener-based service,
|
||||
// because the firewall is typically not available at construction time.
|
||||
func (s *DefaultServer) SetFirewall(fw Firewall) {
|
||||
if svc, ok := s.service.(*serviceViaListener); ok {
|
||||
svc.listenerFlagLock.Lock()
|
||||
svc.firewall = fw
|
||||
svc.listenerFlagLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.probeMu.Lock()
|
||||
@@ -395,8 +412,12 @@ func (s *DefaultServer) Stop() {
|
||||
maps.Clear(s.extraDomains)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) disableDNS() error {
|
||||
defer s.service.Stop()
|
||||
func (s *DefaultServer) disableDNS() (retErr error) {
|
||||
defer func() {
|
||||
if err := s.service.Stop(); err != nil {
|
||||
retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err))
|
||||
}
|
||||
}()
|
||||
|
||||
if s.isUsingNoopHostManager() {
|
||||
return nil
|
||||
|
||||
@@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
@@ -1071,7 +1071,7 @@ func (m *mockHandler) ID() types.HandlerID { return types.Hand
|
||||
type mockService struct{}
|
||||
|
||||
func (m *mockService) Listen() error { return nil }
|
||||
func (m *mockService) Stop() {}
|
||||
func (m *mockService) Stop() error { return nil }
|
||||
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
|
||||
func (m *mockService) RuntimePort() int { return 53 }
|
||||
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||
|
||||
@@ -4,15 +4,25 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultPort = 53
|
||||
)
|
||||
|
||||
// Firewall provides DNAT capabilities for DNS port redirection.
|
||||
// This is used when the DNS server cannot bind port 53 directly
|
||||
// and needs firewall rules to redirect traffic.
|
||||
type Firewall interface {
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||
}
|
||||
|
||||
type service interface {
|
||||
Listen() error
|
||||
Stop()
|
||||
Stop() error
|
||||
RegisterMux(domain string, handler dns.Handler)
|
||||
DeregisterMux(key string)
|
||||
RuntimePort() int
|
||||
|
||||
@@ -10,9 +10,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
@@ -31,25 +35,33 @@ type serviceViaListener struct {
|
||||
dnsMux *dns.ServeMux
|
||||
customAddr *netip.AddrPort
|
||||
server *dns.Server
|
||||
tcpServer *dns.Server
|
||||
listenIP netip.Addr
|
||||
listenPort uint16
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
ebpfService ebpfMgr.Manager
|
||||
firewall Firewall
|
||||
tcpDNATConfigured bool
|
||||
}
|
||||
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
s := &serviceViaListener{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: mux,
|
||||
customAddr: customAddr,
|
||||
firewall: fw,
|
||||
server: &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
},
|
||||
tcpServer: &dns.Server{
|
||||
Net: "tcp",
|
||||
Handler: mux,
|
||||
},
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -70,43 +82,86 @@ func (s *serviceViaListener) Listen() error {
|
||||
return fmt.Errorf("eval listen address: %w", err)
|
||||
}
|
||||
s.listenIP = s.listenIP.Unmap()
|
||||
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
s.server.Addr = addr
|
||||
s.tcpServer.Addr = addr
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
|
||||
log.Debugf("starting dns on %s (UDP + TCP)", addr)
|
||||
s.listenerIsRunning = true
|
||||
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil {
|
||||
log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err)
|
||||
}
|
||||
|
||||
s.listenerFlagLock.Lock()
|
||||
unexpected := s.listenerIsRunning
|
||||
s.listenerIsRunning = false
|
||||
s.listenerFlagLock.Unlock()
|
||||
|
||||
if unexpected {
|
||||
if err := s.tcpServer.Shutdown(); err != nil {
|
||||
log.Debugf("failed to shutdown DNS TCP server: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if err := s.tcpServer.ListenAndServe(); err != nil {
|
||||
log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
|
||||
// a DNAT rule because eBPF only handles UDP.
|
||||
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
|
||||
if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
||||
log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err)
|
||||
} else {
|
||||
s.tcpDNATConfigured = true
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) Stop() {
|
||||
func (s *serviceViaListener) Stop() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
s.listenerIsRunning = false
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := s.server.ShutdownContext(ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err))
|
||||
}
|
||||
|
||||
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err))
|
||||
}
|
||||
|
||||
if s.tcpDNATConfigured && s.firewall != nil {
|
||||
if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||
}
|
||||
s.tcpDNATConfigured = false
|
||||
}
|
||||
|
||||
if s.ebpfService != nil {
|
||||
err = s.ebpfService.FreeDNSFwd()
|
||||
if err != nil {
|
||||
log.Errorf("stopping traffic forwarder returned an error: %v", err)
|
||||
if err := s.ebpfService.FreeDNSFwd(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -133,12 +188,6 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
|
||||
return s.listenIP
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// evalListenAddress figure out the listen address for the DNS server
|
||||
// first check the 53 port availability on WG interface or lo, if not success
|
||||
@@ -187,18 +236,28 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
|
||||
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(port))
|
||||
|
||||
udpAddr := net.UDPAddrFromAddrPort(addrPort)
|
||||
udpLn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
|
||||
return false
|
||||
}
|
||||
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
if err := udpLn.Close(); err != nil {
|
||||
log.Debugf("close UDP probe listener: %s", err)
|
||||
}
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
|
||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
|
||||
return false
|
||||
}
|
||||
if err := tcpLn.Close(); err != nil {
|
||||
log.Debugf("close TCP probe listener: %s", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
86
client/internal/dns/service_listener_test.go
Normal file
86
client/internal/dns/service_listener_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("192.0.2.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Create a service using a custom address to avoid needing root
|
||||
svc := newServiceViaListener(nil, nil, nil)
|
||||
svc.dnsMux.Handle(".", handler)
|
||||
|
||||
// Bind both transports up front to avoid TOCTOU races.
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
t.Skip("cannot bind to 127.0.0.153, skipping")
|
||||
}
|
||||
port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
|
||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
t.Skip("cannot bind TCP on same port, skipping")
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", customIP, port)
|
||||
svc.server.PacketConn = udpConn
|
||||
svc.tcpServer.Listener = tcpLn
|
||||
svc.listenIP = customIP
|
||||
svc.listenPort = port
|
||||
|
||||
go func() {
|
||||
if err := svc.server.ActivateAndServe(); err != nil {
|
||||
t.Logf("udp server: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := svc.tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
svc.listenerIsRunning = true
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, svc.Stop())
|
||||
}()
|
||||
|
||||
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
// Test UDP query
|
||||
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
|
||||
udpResp, _, err := udpClient.Exchange(q, addr)
|
||||
require.NoError(t, err, "UDP query should succeed")
|
||||
require.NotNil(t, udpResp)
|
||||
require.NotEmpty(t, udpResp.Answer)
|
||||
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
|
||||
|
||||
// Test TCP query
|
||||
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
|
||||
tcpResp, _, err := tcpClient.Exchange(q, addr)
|
||||
require.NoError(t, err, "TCP query should succeed")
|
||||
require.NotNil(t, tcpResp)
|
||||
require.NotEmpty(t, tcpResp.Answer)
|
||||
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
@@ -18,7 +20,8 @@ type ServiceViaMemory struct {
|
||||
dnsMux *dns.ServeMux
|
||||
runtimeIP netip.Addr
|
||||
runtimePort int
|
||||
udpFilterHookID string
|
||||
tcpDNS *tcpDNSServer
|
||||
tcpHookSet bool
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
@@ -28,14 +31,13 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
if err != nil {
|
||||
log.Errorf("get last ip from network: %v", err)
|
||||
}
|
||||
s := &ServiceViaMemory{
|
||||
|
||||
return &ServiceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: lastIP,
|
||||
runtimePort: DefaultPort,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Listen() error {
|
||||
@@ -46,10 +48,8 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||
if err != nil {
|
||||
return fmt.Errorf("filter dns traffice: %w", err)
|
||||
if err := s.filterDNSTraffic(); err != nil {
|
||||
return fmt.Errorf("filter dns traffic: %w", err)
|
||||
}
|
||||
s.listenerIsRunning = true
|
||||
|
||||
@@ -57,19 +57,29 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Stop() {
|
||||
func (s *ServiceViaMemory) Stop() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter != nil {
|
||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
||||
if s.tcpHookSet {
|
||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
||||
}
|
||||
}
|
||||
|
||||
if s.tcpDNS != nil {
|
||||
s.tcpDNS.Stop()
|
||||
}
|
||||
|
||||
s.listenerIsRunning = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -88,10 +98,18 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() error {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||
return errors.New("DNS filter not initialized")
|
||||
}
|
||||
|
||||
// Create TCP DNS server lazily here since the device may not exist at construction time.
|
||||
if s.tcpDNS == nil {
|
||||
if dev := s.wgInterface.GetDevice(); dev != nil {
|
||||
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
|
||||
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
|
||||
}
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
@@ -100,12 +118,16 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
if udpLayer == nil {
|
||||
return true
|
||||
}
|
||||
udp, ok := udpLayer.(*layers.UDP)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
@@ -113,13 +135,30 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
return true
|
||||
}
|
||||
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
dev := s.wgInterface.GetDevice()
|
||||
if dev == nil {
|
||||
return true
|
||||
}
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
|
||||
writer := &responseWriter{
|
||||
remote: remoteAddrFromPacket(packet),
|
||||
packet: packet,
|
||||
device: dev.Device,
|
||||
}
|
||||
go s.dnsMux.ServeDNS(writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
|
||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
|
||||
|
||||
if s.tcpDNS != nil {
|
||||
tcpHook := func(packetData []byte) bool {
|
||||
s.tcpDNS.InjectPacket(packetData)
|
||||
return true
|
||||
}
|
||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
|
||||
s.tcpHookSet = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
444
client/internal/dns/tcpstack.go
Normal file
444
client/internal/dns/tcpstack.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsTCPReceiveWindow = 8192
|
||||
dnsTCPMaxInFlight = 16
|
||||
dnsTCPIdleTimeout = 30 * time.Second
|
||||
dnsTCPReadTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
|
||||
// It is started lazily when a truncated DNS response is detected and shuts down
|
||||
// after a period of inactivity to conserve resources.
|
||||
type tcpDNSServer struct {
|
||||
mu sync.Mutex
|
||||
s *stack.Stack
|
||||
ep *dnsEndpoint
|
||||
mux *dns.ServeMux
|
||||
tunDev tun.Device
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
mtu uint16
|
||||
|
||||
running bool
|
||||
closed bool
|
||||
timerID uint64
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
|
||||
return &tcpDNSServer{
|
||||
mux: mux,
|
||||
tunDev: tunDev,
|
||||
ip: ip,
|
||||
port: port,
|
||||
mtu: mtu,
|
||||
}
|
||||
}
|
||||
|
||||
// InjectPacket ensures the stack is running and delivers a raw IP packet into
|
||||
// the gvisor stack for TCP processing. Combining both operations under a single
|
||||
// lock prevents a race where the idle timer could stop the stack between
|
||||
// start and delivery.
|
||||
func (t *tcpDNSServer) InjectPacket(payload []byte) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if !t.running {
|
||||
if err := t.startLocked(); err != nil {
|
||||
log.Errorf("failed to start TCP DNS stack: %v", err)
|
||||
return
|
||||
}
|
||||
t.running = true
|
||||
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
|
||||
}
|
||||
t.resetTimerLocked()
|
||||
|
||||
ep := t.ep
|
||||
if ep == nil || ep.dispatcher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
|
||||
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
}
|
||||
|
||||
// Stop tears down the gvisor stack and releases resources permanently.
|
||||
// After Stop, InjectPacket becomes a no-op.
|
||||
func (t *tcpDNSServer) Stop() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.stopLocked()
|
||||
t.closed = true
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) startLocked() error {
|
||||
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
nicID := tcpip.NICID(1)
|
||||
ep := &dnsEndpoint{
|
||||
tunDev: t.tunDev,
|
||||
}
|
||||
ep.mtu.Store(uint32(t.mtu))
|
||||
|
||||
if err := s.CreateNIC(nicID, ep); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("create NIC: %v", err)
|
||||
}
|
||||
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
|
||||
PrefixLen: 32,
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("add protocol address: %s", err)
|
||||
}
|
||||
|
||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("set promiscuous mode: %s", err)
|
||||
}
|
||||
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("set spoofing: %s", err)
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
)
|
||||
if err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("create default subnet: %w", err)
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{Destination: defaultSubnet, NIC: nicID},
|
||||
})
|
||||
|
||||
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
|
||||
t.handleTCPDNS(r)
|
||||
})
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
|
||||
|
||||
t.s = s
|
||||
t.ep = ep
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) stopLocked() {
|
||||
if !t.running {
|
||||
return
|
||||
}
|
||||
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
t.timer = nil
|
||||
}
|
||||
|
||||
if t.s != nil {
|
||||
t.s.Close()
|
||||
t.s.Wait()
|
||||
t.s = nil
|
||||
}
|
||||
t.ep = nil
|
||||
t.running = false
|
||||
|
||||
log.Debugf("TCP DNS stack stopped")
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) resetTimerLocked() {
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
}
|
||||
t.timerID++
|
||||
id := t.timerID
|
||||
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Only stop if this timer is still the active one.
|
||||
// A racing InjectPacket may have replaced it.
|
||||
if t.timerID != id {
|
||||
return
|
||||
}
|
||||
t.stopLocked()
|
||||
})
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
wq := waiter.Queue{}
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
|
||||
conn := gonet.NewTCPConn(&wq, ep)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Tracef("TCP DNS: close conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Reset idle timer on activity
|
||||
t.mu.Lock()
|
||||
t.resetTimerLocked()
|
||||
t.mu.Unlock()
|
||||
|
||||
localAddr := &net.TCPAddr{
|
||||
IP: id.LocalAddress.AsSlice(),
|
||||
Port: int(id.LocalPort),
|
||||
}
|
||||
remoteAddr := &net.TCPAddr{
|
||||
IP: id.RemoteAddress.AsSlice(),
|
||||
Port: int(id.RemotePort),
|
||||
}
|
||||
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
|
||||
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
|
||||
break
|
||||
}
|
||||
|
||||
msg, err := readTCPDNSMessage(conn)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
writer := &tcpResponseWriter{
|
||||
conn: conn,
|
||||
localAddr: localAddr,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
t.mux.ServeDNS(writer, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
|
||||
type dnsEndpoint struct {
|
||||
dispatcher stack.NetworkDispatcher
|
||||
tunDev tun.Device
|
||||
mtu atomic.Uint32
|
||||
}
|
||||
|
||||
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
|
||||
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
|
||||
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
|
||||
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
|
||||
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
|
||||
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
|
||||
func (e *dnsEndpoint) Wait() { /* no async work */ }
|
||||
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
|
||||
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
|
||||
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
|
||||
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
|
||||
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
|
||||
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
|
||||
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
|
||||
|
||||
const tunPacketOffset = 40
|
||||
|
||||
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
raw := data.AsSlice()
|
||||
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
|
||||
buf = append(buf, raw...)
|
||||
data.Release()
|
||||
|
||||
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
|
||||
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
|
||||
continue
|
||||
}
|
||||
written++
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
|
||||
type tcpResponseWriter struct {
|
||||
conn *gonet.TCPConn
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) LocalAddr() net.Addr {
|
||||
return w.localAddr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
|
||||
return w.remoteAddr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pack: %w", err)
|
||||
}
|
||||
|
||||
// DNS TCP: 2-byte length prefix + message
|
||||
buf := make([]byte, 2+len(data))
|
||||
buf[0] = byte(len(data) >> 8)
|
||||
buf[1] = byte(len(data))
|
||||
copy(buf[2:], data)
|
||||
|
||||
if _, err = w.conn.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
|
||||
buf := make([]byte, 2+len(data))
|
||||
buf[0] = byte(len(data) >> 8)
|
||||
buf[1] = byte(len(data))
|
||||
copy(buf[2:], data)
|
||||
if _, err := w.conn.Write(buf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) Close() error {
|
||||
return w.conn.Close()
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) TsigStatus() error { return nil }
|
||||
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
|
||||
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
|
||||
|
||||
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
|
||||
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
|
||||
// DNS over TCP uses a 2-byte length prefix
|
||||
lenBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return nil, fmt.Errorf("read length: %w", err)
|
||||
}
|
||||
|
||||
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
|
||||
if msgLen == 0 || msgLen > 65535 {
|
||||
return nil, fmt.Errorf("invalid message length: %d", msgLen)
|
||||
}
|
||||
|
||||
msgBuf := make([]byte, msgLen)
|
||||
if _, err := io.ReadFull(conn, msgBuf); err != nil {
|
||||
return nil, fmt.Errorf("read message: %w", err)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(msgBuf); err != nil {
|
||||
return nil, fmt.Errorf("unpack: %w", err)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
|
||||
// Supports both IPv4 and IPv6.
|
||||
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
|
||||
if len(pkt) == 0 {
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
srcIP, transportOffset := srcIPFromPacket(pkt)
|
||||
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
|
||||
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
|
||||
}
|
||||
|
||||
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
|
||||
switch header.IPVersion(pkt) {
|
||||
case 4:
|
||||
return srcIPv4(pkt)
|
||||
case 6:
|
||||
return srcIPv6(pkt)
|
||||
default:
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
}
|
||||
|
||||
func srcIPv4(pkt []byte) (netip.Addr, int) {
|
||||
if len(pkt) < header.IPv4MinimumSize {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
hdr := header.IPv4(pkt)
|
||||
src := hdr.SourceAddress()
|
||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
||||
if !ok {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
return ip, int(hdr.HeaderLength())
|
||||
}
|
||||
|
||||
func srcIPv6(pkt []byte) (netip.Addr, int) {
|
||||
if len(pkt) < header.IPv6MinimumSize {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
hdr := header.IPv6(pkt)
|
||||
src := hdr.SourceAddress()
|
||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
||||
if !ok {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
return ip, header.IPv6MinimumSize
|
||||
}
|
||||
@@ -41,10 +41,61 @@ const (
|
||||
|
||||
reactivatePeriod = 30 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
|
||||
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
|
||||
// payload from the tunnel MTU.
|
||||
ipUDPHeaderSize = 60 + 8
|
||||
)
|
||||
|
||||
const testRecord = "com."
|
||||
|
||||
const (
|
||||
protoUDP = "udp"
|
||||
protoTCP = "tcp"
|
||||
)
|
||||
|
||||
type dnsProtocolKey struct{}
|
||||
|
||||
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
|
||||
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
|
||||
return context.WithValue(ctx, dnsProtocolKey{}, network)
|
||||
}
|
||||
|
||||
// dnsProtocolFromContext retrieves the inbound DNS protocol from context.
|
||||
func dnsProtocolFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type upstreamProtocolKey struct{}
|
||||
|
||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||
// Stored as a pointer in context so the exchange function can set it.
|
||||
type upstreamProtocolResult struct {
|
||||
protocol string
|
||||
}
|
||||
|
||||
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
|
||||
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||
r := &upstreamProtocolResult{}
|
||||
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
|
||||
}
|
||||
|
||||
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
|
||||
func setUpstreamProtocol(ctx context.Context, protocol string) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil {
|
||||
r.protocol = protocol
|
||||
}
|
||||
}
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
@@ -138,7 +189,16 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
ok, failures := u.tryUpstreamServers(w, r, logger)
|
||||
// Propagate inbound protocol so upstream exchange can use TCP directly
|
||||
// when the request came in over TCP.
|
||||
ctx := u.ctx
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
network := addr.Network()
|
||||
ctx = contextWithDNSProtocol(ctx, network)
|
||||
resutil.SetMeta(w, "protocol", network)
|
||||
}
|
||||
|
||||
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
|
||||
if len(failures) > 0 {
|
||||
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
|
||||
}
|
||||
@@ -153,7 +213,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
@@ -168,7 +228,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
|
||||
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
} else {
|
||||
return true, failures
|
||||
@@ -178,15 +238,17 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
}
|
||||
|
||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
|
||||
var startTime time.Time
|
||||
var upstreamProto *upstreamProtocolResult
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
||||
startTime = time.Now()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||
}()
|
||||
@@ -203,7 +265,7 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
}
|
||||
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -220,10 +282,13 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
|
||||
u.successCount.Add(1)
|
||||
|
||||
resutil.SetMeta(w, "upstream", upstream.String())
|
||||
if upstreamProto != nil && upstreamProto.protocol != "" {
|
||||
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
|
||||
}
|
||||
|
||||
// Clear Zero bit from external responses to prevent upstream servers from
|
||||
// manipulating our internal fallthrough signaling mechanism
|
||||
@@ -428,13 +493,42 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC
|
||||
return err
|
||||
}
|
||||
|
||||
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
|
||||
func clientUDPMaxSize(r *dns.Msg) int {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
return int(opt.UDPSize())
|
||||
}
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||
// If the inbound request came over TCP (via context), it skips the UDP attempt.
|
||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||
// MTU - ip + udp headers
|
||||
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
|
||||
client.UDPSize = uint16(currentMTU - (60 + 8))
|
||||
// If the request came in over TCP, go straight to TCP upstream.
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
clientMaxSize := clientUDPMaxSize(r)
|
||||
|
||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
// Note: the query could be sent out on an interface that is not ours,
|
||||
// but higher MTU settings could break truncation handling.
|
||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
||||
client.UDPSize = maxUDPPayload
|
||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
var (
|
||||
rm *dns.Msg
|
||||
@@ -453,25 +547,32 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
}
|
||||
|
||||
if rm == nil || !rm.MsgHdr.Truncated {
|
||||
setUpstreamProtocol(ctx, protoUDP)
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
// TODO: if the upstream's truncated UDP response already contains more
|
||||
// data than the client's buffer, we could truncate locally and skip
|
||||
// the TCP retry.
|
||||
|
||||
client.Net = "tcp"
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
rm, t, err = tcpClient.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
|
||||
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
|
||||
if rm.Len() > clientMaxSize {
|
||||
rm.Truncate(clientMaxSize)
|
||||
}
|
||||
|
||||
return rm, t, nil
|
||||
}
|
||||
@@ -479,18 +580,46 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
|
||||
// If request came in over TCP, go straight to TCP upstream
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
clientMaxSize := clientUDPMaxSize(r)
|
||||
|
||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
||||
// response larger than what we can read over UDP.
|
||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If response is truncated, retry with TCP
|
||||
if reply != nil && reply.MsgHdr.Truncated {
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
|
||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
if rm.Len() > clientMaxSize {
|
||||
rm.Truncate(clientMaxSize)
|
||||
}
|
||||
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoUDP)
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
@@ -511,7 +640,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
|
||||
}
|
||||
}
|
||||
|
||||
dnsConn := &dns.Conn{Conn: conn}
|
||||
dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)}
|
||||
|
||||
if err := dnsConn.WriteMsg(r); err != nil {
|
||||
return nil, fmt.Errorf("write %s message: %w", network, err)
|
||||
|
||||
@@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
||||
upstreamExchangeClient := &dns.Client{
|
||||
Timeout: ClientTimeout,
|
||||
}
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
||||
}
|
||||
|
||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||
@@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||
|
||||
@@ -475,3 +475,298 @@ func TestFormatFailures(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSProtocolContext(t *testing.T) {
|
||||
t.Run("roundtrip udp", func(t *testing.T) {
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoUDP)
|
||||
assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("roundtrip tcp", func(t *testing.T) {
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("missing returns empty", func(t *testing.T) {
|
||||
assert.Equal(t, "", dnsProtocolFromContext(context.Background()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPContext(t *testing.T) {
|
||||
// Start a local DNS server that responds on TCP only
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Addr: "127.0.0.1:0",
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
tcpServer.Listener = tcpLn
|
||||
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
upstream := tcpLn.Addr().String()
|
||||
|
||||
// With TCP context, should connect directly via TCP without trying UDP
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer)
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
|
||||
// UDP handler returns a truncated response to trigger TCP retry.
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Truncated = true
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// TCP handler returns the full answer.
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.3"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{
|
||||
PacketConn: udpPC,
|
||||
Net: "udp",
|
||||
Handler: udpHandler,
|
||||
}
|
||||
|
||||
tcpLn, err := net.Listen("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Listener: tcpLn,
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := udpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("udp server: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = udpServer.Shutdown()
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err, "should fall back to TCP after truncated UDP response")
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer")
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
|
||||
assert.False(t, rm.Truncated, "TCP response should not be truncated")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
|
||||
// Start only a TCP server (no UDP). With TCP context it should succeed.
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.2"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Listener: tcpLn,
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
upstream := tcpLn.Addr().String()
|
||||
|
||||
// TCP context: should skip UDP entirely and go directly to TCP
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer)
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
|
||||
|
||||
// Without TCP context, trying to reach a TCP-only server via UDP should fail
|
||||
ctx2 := context.Background()
|
||||
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
|
||||
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
|
||||
assert.Error(t, err, "should fail when no UDP server and no TCP context")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
||||
// Verify that a client EDNS0 larger than our MTU-derived limit gets
|
||||
// capped in the outgoing request so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
var receivedUDPSize uint16
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
receivedUDPSize = opt.UDPSize()
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
||||
go func() { _ = udpServer.ActivateAndServe() }()
|
||||
t.Cleanup(func() { _ = udpServer.Shutdown() })
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
r.SetEdns0(4096, false)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
|
||||
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
|
||||
assert.Equal(t, expectedMax, receivedUDPSize,
|
||||
"upstream should see capped EDNS0, not the client's 4096")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
|
||||
// When the client advertises a large EDNS0 (4096) and the upstream
|
||||
// truncates, the TCP response should NOT be truncated since the full
|
||||
// answer fits within the client's original buffer.
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Truncated = true
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
// Add enough records to exceed MTU but fit within 4096
|
||||
for i := range 20 {
|
||||
m.Answer = append(m.Answer, &dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
|
||||
Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)},
|
||||
})
|
||||
}
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
||||
tcpLn, err := net.Listen("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler}
|
||||
|
||||
go func() { _ = udpServer.ActivateAndServe() }()
|
||||
go func() { _ = tcpServer.ActivateAndServe() }()
|
||||
t.Cleanup(func() {
|
||||
_ = udpServer.Shutdown()
|
||||
_ = tcpServer.Shutdown()
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
|
||||
// Client with large buffer: should get all records without truncation
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
||||
r.SetEdns0(4096, false)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records")
|
||||
assert.False(t, rm.Truncated, "response should not be truncated for large buffer client")
|
||||
|
||||
// Client with small buffer: should get truncated response
|
||||
r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
||||
r2.SetEdns0(512, false)
|
||||
|
||||
rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm2)
|
||||
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
|
||||
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
|
||||
}
|
||||
|
||||
@@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime))
|
||||
}
|
||||
|
||||
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||
@@ -263,20 +263,28 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
f.handleDNSQuery(logger, w, query, startTime)
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
@@ -116,6 +117,7 @@ type EngineConfig struct {
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
ServerRDPAllowed bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -196,6 +198,7 @@ type Engine struct {
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
rdpServer rdpServer
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -210,9 +213,10 @@ type Engine struct {
|
||||
// checks are the client-applied posture checks that need to be evaluated on the client
|
||||
checks []*mgmProto.Checks
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
portForwardManager *portforward.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
// Sync response persistence (protected by syncRespMux)
|
||||
syncRespMux sync.RWMutex
|
||||
@@ -259,26 +263,27 @@ func NewEngine(
|
||||
mobileDep MobileDependency,
|
||||
) *Engine {
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
relayManager: services.RelayManager,
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
stateManager: services.StateManager,
|
||||
checks: services.Checks,
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
clientMetrics: services.ClientMetrics,
|
||||
updateManager: services.UpdateManager,
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
relayManager: services.RelayManager,
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
stateManager: services.StateManager,
|
||||
portForwardManager: portforward.NewManager(),
|
||||
checks: services.Checks,
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
clientMetrics: services.ClientMetrics,
|
||||
updateManager: services.UpdateManager,
|
||||
}
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
@@ -500,7 +505,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
|
||||
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
|
||||
for _, r := range routes {
|
||||
if r.Network.Contains(ip) {
|
||||
return true
|
||||
@@ -521,6 +526,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return err
|
||||
}
|
||||
|
||||
// Inject firewall into DNS server now that it's available.
|
||||
// The DNS server is created before the firewall because the route manager
|
||||
// depends on the DNS server, and the firewall depends on the wg interface.
|
||||
e.dnsServer.SetFirewall(e.firewall)
|
||||
|
||||
e.udpMux, err = e.wgInterface.Up()
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
@@ -532,6 +542,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
// conntrack entries from being created before the rules are in place
|
||||
e.setupWGProxyNoTrack()
|
||||
|
||||
// Start after interface is up since port may have been resolved from 0 or changed if occupied
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
|
||||
}()
|
||||
|
||||
// Set the WireGuard interface for rosenpass after interface is up
|
||||
if e.rpManager != nil {
|
||||
e.rpManager.SetInterface(e.wgInterface)
|
||||
@@ -1021,6 +1038,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.updateRDP(); err != nil {
|
||||
log.Warnf("failed handling RDP server setup: %v", err)
|
||||
}
|
||||
|
||||
state := e.statusRecorder.GetLocalPeerState()
|
||||
state.IP = e.wgInterface.Address().String()
|
||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||
@@ -1308,6 +1329,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
|
||||
// Reuse SSH ACL for RDP authorization
|
||||
e.updateRDPServerAuth(networkMap.GetSshAuth())
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
@@ -1535,12 +1559,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
MetricsRecorder: e.clientMetrics,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
PortForwardManager: e.portForwardManager,
|
||||
MetricsRecorder: e.clientMetrics,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
@@ -1697,6 +1722,12 @@ func (e *Engine) close() {
|
||||
if e.rpManager != nil {
|
||||
_ = e.rpManager.Close()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
|
||||
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
|
||||
@@ -1800,7 +1831,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||
return dnsServer, nil
|
||||
|
||||
case "ios":
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
|
||||
return dnsServer, nil
|
||||
|
||||
default:
|
||||
@@ -1837,6 +1868,11 @@ func (e *Engine) GetExposeManager() *expose.Manager {
|
||||
return e.exposeManager
|
||||
}
|
||||
|
||||
// IsBlockInbound returns whether inbound connections are blocked.
|
||||
func (e *Engine) IsBlockInbound() bool {
|
||||
return e.config.BlockInbound
|
||||
}
|
||||
|
||||
// GetClientMetrics returns the client metrics
|
||||
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
|
||||
191
client/internal/engine_rdp.go
Normal file
191
client/internal/engine_rdp.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
type rdpServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort) error
|
||||
Stop() error
|
||||
GetPendingStore() *rdpserver.PendingStore
|
||||
UpdateRDPAuth(config *sshauth.Config)
|
||||
}
|
||||
|
||||
func (e *Engine) setupRDPPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, rdpserver.DefaultRDPAuthPort, rdpserver.InternalRDPAuthPort); err != nil {
|
||||
return fmt.Errorf("add RDP auth port redirection: %w", err)
|
||||
}
|
||||
log.Infof("RDP auth port redirection enabled: %s:%d -> %s:%d",
|
||||
localAddr, rdpserver.DefaultRDPAuthPort, localAddr, rdpserver.InternalRDPAuthPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupRDPPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, rdpserver.DefaultRDPAuthPort, rdpserver.InternalRDPAuthPort); err != nil {
|
||||
return fmt.Errorf("remove RDP auth port redirection: %w", err)
|
||||
}
|
||||
log.Debugf("RDP auth port redirection removed: %s:%d -> %s:%d",
|
||||
localAddr, rdpserver.DefaultRDPAuthPort, localAddr, rdpserver.InternalRDPAuthPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateRDP handles starting/stopping the RDP server based on the config flag.
|
||||
func (e *Engine) updateRDP() error {
|
||||
if !e.config.ServerRDPAllowed {
|
||||
if e.rdpServer != nil {
|
||||
log.Info("RDP passthrough disabled, stopping RDP auth server")
|
||||
}
|
||||
return e.stopRDPServer()
|
||||
}
|
||||
|
||||
if e.config.BlockInbound {
|
||||
log.Info("RDP server is disabled because inbound connections are blocked")
|
||||
return e.stopRDPServer()
|
||||
}
|
||||
|
||||
if e.rdpServer != nil {
|
||||
log.Debug("RDP auth server is already running")
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.startRDPServer()
|
||||
}
|
||||
|
||||
func (e *Engine) startRDPServer() error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
|
||||
cfg := &rdpserver.Config{
|
||||
NetworkAddr: wgAddr.Network,
|
||||
}
|
||||
|
||||
server := rdpserver.New(cfg)
|
||||
|
||||
netbirdIP := wgAddr.IP
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, rdpserver.InternalRDPAuthPort)
|
||||
|
||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||
return fmt.Errorf("start RDP auth server: %w", err)
|
||||
}
|
||||
|
||||
e.rdpServer = server
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, rdpserver.InternalRDPAuthPort)
|
||||
log.Debugf("registered RDP auth service with netstack for TCP:%d", rdpserver.InternalRDPAuthPort)
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.setupRDPPortRedirection(); err != nil {
|
||||
log.Warnf("failed to setup RDP auth port redirection: %v", err)
|
||||
}
|
||||
|
||||
// Register the credential provider DLL dynamically (Windows only)
|
||||
if err := rdpserver.RegisterCredentialProvider(); err != nil {
|
||||
log.Warnf("failed to register RDP credential provider (passwordless RDP will not work): %v", err)
|
||||
}
|
||||
|
||||
log.Info("RDP passthrough enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopRDPServer() error {
|
||||
if e.rdpServer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupRDPPortRedirection(); err != nil {
|
||||
log.Warnf("failed to cleanup RDP auth port redirection: %v", err)
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, rdpserver.InternalRDPAuthPort)
|
||||
log.Debugf("unregistered RDP auth service from netstack for TCP:%d", rdpserver.InternalRDPAuthPort)
|
||||
}
|
||||
}
|
||||
|
||||
// Unregister the credential provider DLL (Windows only)
|
||||
if err := rdpserver.UnregisterCredentialProvider(); err != nil {
|
||||
log.Warnf("failed to unregister RDP credential provider: %v", err)
|
||||
}
|
||||
|
||||
log.Info("stopping RDP auth server")
|
||||
err := e.rdpServer.Stop()
|
||||
e.rdpServer = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateRDPServerAuth reuses the SSH authorization config for RDP access control.
|
||||
// This means the same user/machine-user mappings that control SSH access also control RDP.
|
||||
func (e *Engine) updateRDPServerAuth(sshAuth *mgmProto.SSHAuth) {
|
||||
if sshAuth == nil || e.rdpServer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
protoUsers := sshAuth.GetAuthorizedUsers()
|
||||
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||
for i, hash := range protoUsers {
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("invalid hash length %d, expected 16 - skipping RDP server auth update", len(hash))
|
||||
return
|
||||
}
|
||||
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||
}
|
||||
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range sshAuth.GetMachineUsers() {
|
||||
machineUsers[osUser] = indexes.GetIndexes()
|
||||
}
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshAuth.GetUserIDClaim(),
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
}
|
||||
|
||||
e.rdpServer.UpdateRDPAuth(authConfig)
|
||||
}
|
||||
@@ -828,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
@@ -1035,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
@@ -1538,13 +1538,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey, err := mgmtClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1566,7 +1561,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -45,6 +46,7 @@ type ServiceDependencies struct {
|
||||
RelayManager *relayClient.Manager
|
||||
SrWatcher *guard.SRWatcher
|
||||
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
PortForwardManager *portforward.Manager
|
||||
MetricsRecorder MetricsRecorder
|
||||
}
|
||||
|
||||
@@ -87,16 +89,17 @@ type ConnConfig struct {
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
Log *log.Entry
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
Log *log.Entry
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
portForwardManager *portforward.Manager
|
||||
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
onDisconnected func(remotePeer string)
|
||||
@@ -145,19 +148,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
|
||||
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
|
||||
var conn = &Conn{
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: dumpState,
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||
metricsRecorder: services.MetricsRecorder,
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
portForwardManager: services.PortForwardManager,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: dumpState,
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||
metricsRecorder: services.MetricsRecorder,
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -61,6 +62,9 @@ type WorkerICE struct {
|
||||
|
||||
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
|
||||
lastKnownState ice.ConnectionState
|
||||
|
||||
// portForwardAttempted tracks if we've already tried port forwarding this session
|
||||
portForwardAttempted bool
|
||||
}
|
||||
|
||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
|
||||
@@ -214,6 +218,8 @@ func (w *WorkerICE) Close() {
|
||||
}
|
||||
|
||||
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
||||
w.portForwardAttempted = false
|
||||
|
||||
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create agent: %w", err)
|
||||
@@ -370,6 +376,93 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
|
||||
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if candidate.Type() == ice.CandidateTypeServerReflexive {
|
||||
w.injectPortForwardedCandidate(candidate)
|
||||
}
|
||||
}
|
||||
|
||||
// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
|
||||
func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
|
||||
pfManager := w.conn.portForwardManager
|
||||
if pfManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
mapping := pfManager.GetMapping()
|
||||
if mapping == nil {
|
||||
return
|
||||
}
|
||||
|
||||
w.muxAgent.Lock()
|
||||
if w.portForwardAttempted {
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.portForwardAttempted = true
|
||||
w.muxAgent.Unlock()
|
||||
|
||||
forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
|
||||
if err != nil {
|
||||
w.log.Warnf("create forwarded candidate: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
|
||||
forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())
|
||||
|
||||
go func() {
|
||||
if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
|
||||
w.log.Errorf("signal port-forwarded candidate: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
|
||||
// It uses the NAT gateway's external IP with the forwarded port.
|
||||
func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
|
||||
var externalIP string
|
||||
if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
|
||||
externalIP = mapping.ExternalIP.String()
|
||||
} else {
|
||||
// Fallback to STUN-discovered address if NAT didn't provide external IP
|
||||
externalIP = srflxCandidate.Address()
|
||||
}
|
||||
|
||||
// Per RFC 8445, the related address for srflx is the base (host candidate address).
|
||||
// If the original srflx has unspecified related address, use its own address as base.
|
||||
relAddr := srflxCandidate.RelatedAddress().Address
|
||||
if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
|
||||
relAddr = srflxCandidate.Address()
|
||||
}
|
||||
|
||||
// Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
|
||||
// over regular srflx during ICE connectivity checks.
|
||||
priority := srflxCandidate.Priority() + 1000
|
||||
|
||||
candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
Network: srflxCandidate.NetworkType().String(),
|
||||
Address: externalIP,
|
||||
Port: int(mapping.ExternalPort),
|
||||
Component: srflxCandidate.Component(),
|
||||
Priority: priority,
|
||||
RelAddr: relAddr,
|
||||
RelPort: int(mapping.InternalPort),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create candidate: %w", err)
|
||||
}
|
||||
|
||||
for _, e := range srflxCandidate.Extensions() {
|
||||
if e.Key == ice.ExtensionKeyCandidateID {
|
||||
e.Value = srflxCandidate.ID()
|
||||
}
|
||||
if err := candidate.AddExtension(e); err != nil {
|
||||
return nil, fmt.Errorf("add extension: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return candidate, nil
|
||||
}
|
||||
|
||||
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
|
||||
@@ -411,10 +504,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
|
||||
if !lok || !rok {
|
||||
continue
|
||||
}
|
||||
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
|
||||
w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
|
||||
sessionID,
|
||||
local.NetworkType(), local.Type(), local.Address(),
|
||||
remote.NetworkType(), remote.Type(), remote.Address(),
|
||||
local.NetworkType(), local.Type(), local.Address(), local.Port(),
|
||||
remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
|
||||
stat.CurrentRoundTripTime*1000)
|
||||
}
|
||||
}
|
||||
|
||||
26
client/internal/portforward/env.go
Normal file
26
client/internal/portforward/env.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
|
||||
)
|
||||
|
||||
func isDisabledByEnv() bool {
|
||||
val := os.Getenv(envDisableNATMapper)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
disabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
280
client/internal/portforward/manager.go
Normal file
280
client/internal/portforward/manager.go
Normal file
@@ -0,0 +1,280 @@
|
||||
//go:build !js
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-nat"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMappingTTL = 2 * time.Hour
|
||||
discoveryTimeout = 10 * time.Second
|
||||
mappingDescription = "NetBird"
|
||||
)
|
||||
|
||||
// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
|
||||
// allowing for whitespace/newlines between tags from different router firmware.
|
||||
var upnpErrPermanentLeaseOnly = regexp.MustCompile(`<errorCode>\s*725\s*</errorCode>`)
|
||||
|
||||
// Mapping represents an active NAT port mapping.
|
||||
type Mapping struct {
|
||||
Protocol string
|
||||
InternalPort uint16
|
||||
ExternalPort uint16
|
||||
ExternalIP net.IP
|
||||
NATType string
|
||||
// TTL is the lease duration. Zero means a permanent lease that never expires.
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// TODO: persist mapping state for crash recovery cleanup of permanent leases.
|
||||
// Currently not done because State.Cleanup requires NAT gateway re-discovery,
|
||||
// which blocks startup for ~10s when no gateway is present (affects all clients).
|
||||
|
||||
type Manager struct {
|
||||
cancel context.CancelFunc
|
||||
|
||||
mapping *Mapping
|
||||
mappingLock sync.Mutex
|
||||
|
||||
wgPort uint16
|
||||
|
||||
done chan struct{}
|
||||
stopCtx chan context.Context
|
||||
|
||||
// protect exported functions
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewManager creates a new port forwarding manager.
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
stopCtx: make(chan context.Context, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Start(ctx context.Context, wgPort uint16) {
|
||||
m.mu.Lock()
|
||||
if m.cancel != nil {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if isDisabledByEnv() {
|
||||
log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if wgPort == 0 {
|
||||
log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
m.wgPort = wgPort
|
||||
|
||||
m.done = make(chan struct{})
|
||||
defer close(m.done)
|
||||
|
||||
ctx, m.cancel = context.WithCancel(ctx)
|
||||
m.mu.Unlock()
|
||||
|
||||
gateway, mapping, err := m.setup(ctx)
|
||||
if err != nil {
|
||||
log.Infof("port forwarding setup: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.mappingLock.Lock()
|
||||
m.mapping = mapping
|
||||
m.mappingLock.Unlock()
|
||||
|
||||
m.renewLoop(ctx, gateway, mapping.TTL)
|
||||
|
||||
select {
|
||||
case cleanupCtx := <-m.stopCtx:
|
||||
// block the Start while cleaned up gracefully
|
||||
m.cleanup(cleanupCtx, gateway)
|
||||
default:
|
||||
// return Start immediately and cleanup in background
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
go func() {
|
||||
defer cleanupCancel()
|
||||
m.cleanup(cleanupCtx, gateway)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// GetMapping returns the current mapping if ready, nil otherwise
|
||||
func (m *Manager) GetMapping() *Mapping {
|
||||
m.mappingLock.Lock()
|
||||
defer m.mappingLock.Unlock()
|
||||
|
||||
if m.mapping == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mapping := *m.mapping
|
||||
return &mapping
|
||||
}
|
||||
|
||||
// GracefullyStop cancels the manager and attempts to delete the port mapping.
|
||||
// After GracefullyStop returns, the manager cannot be restarted.
|
||||
func (m *Manager) GracefullyStop(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancel == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
|
||||
m.startTearDown(ctx)
|
||||
|
||||
m.cancel()
|
||||
m.cancel = nil
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
|
||||
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
|
||||
defer discoverCancel()
|
||||
|
||||
gateway, err := nat.DiscoverGateway(discoverCtx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("discover gateway: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("discovered NAT gateway: %s", gateway.Type())
|
||||
|
||||
mapping, err := m.createMapping(ctx, gateway)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create port mapping: %w", err)
|
||||
}
|
||||
return gateway, mapping, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ttl := defaultMappingTTL
|
||||
externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
|
||||
if err != nil {
|
||||
if !isPermanentLeaseRequired(err) {
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("gateway only supports permanent leases, retrying with indefinite duration")
|
||||
ttl = 0
|
||||
externalPort, err = gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
externalIP, err := gateway.GetExternalAddress()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get external address: %v", err)
|
||||
// todo return with err?
|
||||
}
|
||||
|
||||
mapping := &Mapping{
|
||||
Protocol: "udp",
|
||||
InternalPort: m.wgPort,
|
||||
ExternalPort: uint16(externalPort),
|
||||
ExternalIP: externalIP,
|
||||
NATType: gateway.Type(),
|
||||
TTL: ttl,
|
||||
}
|
||||
|
||||
log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
|
||||
m.wgPort, externalPort, gateway.Type(), externalIP)
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
|
||||
if ttl == 0 {
|
||||
// Permanent mappings don't expire, just wait for cancellation.
|
||||
<-ctx.Done()
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(ttl / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := m.renewMapping(ctx, gateway); err != nil {
|
||||
log.Warnf("failed to renew port mapping: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, m.mapping.TTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add port mapping: %w", err)
|
||||
}
|
||||
|
||||
if uint16(externalPort) != m.mapping.ExternalPort {
|
||||
log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
|
||||
m.mappingLock.Lock()
|
||||
m.mapping.ExternalPort = uint16(externalPort)
|
||||
m.mappingLock.Unlock()
|
||||
}
|
||||
|
||||
log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
|
||||
m.mappingLock.Lock()
|
||||
mapping := m.mapping
|
||||
m.mapping = nil
|
||||
m.mappingLock.Unlock()
|
||||
|
||||
if mapping == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
|
||||
log.Warnf("delete port mapping on stop: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("deleted port mapping for port %d", mapping.InternalPort)
|
||||
}
|
||||
|
||||
func (m *Manager) startTearDown(ctx context.Context) {
|
||||
select {
|
||||
case m.stopCtx <- ctx:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// isPermanentLeaseRequired checks if a UPnP error indicates the gateway only supports permanent leases (error 725).
|
||||
func isPermanentLeaseRequired(err error) bool {
|
||||
return err != nil && upnpErrPermanentLeaseOnly.MatchString(err.Error())
|
||||
}
|
||||
39
client/internal/portforward/manager_js.go
Normal file
39
client/internal/portforward/manager_js.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mapping represents an active NAT port mapping.
|
||||
type Mapping struct {
|
||||
Protocol string
|
||||
InternalPort uint16
|
||||
ExternalPort uint16
|
||||
ExternalIP net.IP
|
||||
NATType string
|
||||
// TTL is the lease duration. Zero means a permanent lease that never expires.
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported.
|
||||
type Manager struct{}
|
||||
|
||||
// NewManager returns a stub manager for js/wasm builds.
|
||||
func NewManager() *Manager {
|
||||
return &Manager{}
|
||||
}
|
||||
|
||||
// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments.
|
||||
func (m *Manager) Start(context.Context, uint16) {
|
||||
// no NAT traversal in wasm
|
||||
}
|
||||
|
||||
// GracefullyStop is a no-op on js/wasm.
|
||||
func (m *Manager) GracefullyStop(context.Context) error { return nil }
|
||||
|
||||
// GetMapping always returns nil on js/wasm.
|
||||
func (m *Manager) GetMapping() *Mapping {
|
||||
return nil
|
||||
}
|
||||
201
client/internal/portforward/manager_test.go
Normal file
201
client/internal/portforward/manager_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
//go:build !js
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockNAT struct {
|
||||
natType string
|
||||
deviceAddr net.IP
|
||||
externalAddr net.IP
|
||||
internalAddr net.IP
|
||||
mappings map[int]int
|
||||
addMappingErr error
|
||||
deleteMappingErr error
|
||||
onlyPermanentLeases bool
|
||||
lastTimeout time.Duration
|
||||
}
|
||||
|
||||
func newMockNAT() *mockNAT {
|
||||
return &mockNAT{
|
||||
natType: "Mock-NAT",
|
||||
deviceAddr: net.ParseIP("192.168.1.1"),
|
||||
externalAddr: net.ParseIP("203.0.113.50"),
|
||||
internalAddr: net.ParseIP("192.168.1.100"),
|
||||
mappings: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockNAT) Type() string {
|
||||
return m.natType
|
||||
}
|
||||
|
||||
func (m *mockNAT) GetDeviceAddress() (net.IP, error) {
|
||||
return m.deviceAddr, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) GetExternalAddress() (net.IP, error) {
|
||||
return m.externalAddr, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) GetInternalAddress() (net.IP, error) {
|
||||
return m.internalAddr, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) {
|
||||
if m.addMappingErr != nil {
|
||||
return 0, m.addMappingErr
|
||||
}
|
||||
if m.onlyPermanentLeases && timeout != 0 {
|
||||
return 0, fmt.Errorf("SOAP fault. Code: | Explanation: | Detail: <UPnPError xmlns=\"urn:schemas-upnp-org:control-1-0\"><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>")
|
||||
}
|
||||
externalPort := internalPort
|
||||
m.mappings[internalPort] = externalPort
|
||||
m.lastTimeout = timeout
|
||||
return externalPort, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
|
||||
if m.deleteMappingErr != nil {
|
||||
return m.deleteMappingErr
|
||||
}
|
||||
delete(m.mappings, internalPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestManager_CreateMapping(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.wgPort = 51820
|
||||
|
||||
gateway := newMockNAT()
|
||||
mapping, err := m.createMapping(context.Background(), gateway)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mapping)
|
||||
|
||||
assert.Equal(t, "udp", mapping.Protocol)
|
||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
||||
assert.Equal(t, uint16(51820), mapping.ExternalPort)
|
||||
assert.Equal(t, "Mock-NAT", mapping.NATType)
|
||||
assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4())
|
||||
assert.Equal(t, defaultMappingTTL, mapping.TTL)
|
||||
}
|
||||
|
||||
func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) {
|
||||
m := NewManager()
|
||||
assert.Nil(t, m.GetMapping())
|
||||
}
|
||||
|
||||
func TestManager_GetMapping_ReturnsCopy(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.mapping = &Mapping{
|
||||
Protocol: "udp",
|
||||
InternalPort: 51820,
|
||||
ExternalPort: 51820,
|
||||
}
|
||||
|
||||
mapping := m.GetMapping()
|
||||
require.NotNil(t, mapping)
|
||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
||||
|
||||
// Mutating the returned copy should not affect the manager's mapping.
|
||||
mapping.ExternalPort = 9999
|
||||
assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort)
|
||||
}
|
||||
|
||||
func TestManager_Cleanup_DeletesMapping(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.mapping = &Mapping{
|
||||
Protocol: "udp",
|
||||
InternalPort: 51820,
|
||||
ExternalPort: 51820,
|
||||
}
|
||||
|
||||
gateway := newMockNAT()
|
||||
// Seed the mock so we can verify deletion.
|
||||
gateway.mappings[51820] = 51820
|
||||
|
||||
m.cleanup(context.Background(), gateway)
|
||||
|
||||
_, exists := gateway.mappings[51820]
|
||||
assert.False(t, exists, "mapping should be deleted from gateway")
|
||||
assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared")
|
||||
}
|
||||
|
||||
func TestManager_Cleanup_NilMapping(t *testing.T) {
|
||||
m := NewManager()
|
||||
gateway := newMockNAT()
|
||||
|
||||
// Should not panic or call gateway.
|
||||
m.cleanup(context.Background(), gateway)
|
||||
}
|
||||
|
||||
|
||||
func TestManager_CreateMapping_PermanentLeaseFallback(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.wgPort = 51820
|
||||
|
||||
gateway := newMockNAT()
|
||||
gateway.onlyPermanentLeases = true
|
||||
|
||||
mapping, err := m.createMapping(context.Background(), gateway)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mapping)
|
||||
|
||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
||||
assert.Equal(t, time.Duration(0), mapping.TTL, "should return zero TTL for permanent lease")
|
||||
assert.Equal(t, time.Duration(0), gateway.lastTimeout, "should have retried with zero duration")
|
||||
}
|
||||
|
||||
func TestIsPermanentLeaseRequired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "UPnP error 725",
|
||||
err: fmt.Errorf("SOAP fault. Code: | Detail: <UPnPError><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped error with 725",
|
||||
err: fmt.Errorf("add port mapping: %w", fmt.Errorf("Detail: <errorCode>725</errorCode>")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error 725 with newlines in XML",
|
||||
err: fmt.Errorf("<errorCode>\n 725\n</errorCode>"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "bare 725 without XML tag",
|
||||
err: fmt.Errorf("error code 725"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "unrelated error",
|
||||
err: fmt.Errorf("connection refused"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, isPermanentLeaseRequired(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -41,7 +41,7 @@ const (
|
||||
|
||||
// mgmProber is the subset of management client needed for URL migration probes.
|
||||
type mgmProber interface {
|
||||
GetServerPublicKey() (*wgtypes.Key, error)
|
||||
HealthCheck() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -64,6 +64,7 @@ type ConfigInput struct {
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
ServerRDPAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -114,6 +115,7 @@ type Config struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
ServerRDPAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -415,6 +417,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ServerRDPAllowed != nil {
|
||||
if config.ServerRDPAllowed == nil || *input.ServerRDPAllowed != *config.ServerRDPAllowed {
|
||||
if *input.ServerRDPAllowed {
|
||||
log.Infof("enabling RDP passthrough")
|
||||
} else {
|
||||
log.Infof("disabling RDP passthrough")
|
||||
}
|
||||
config.ServerRDPAllowed = input.ServerRDPAllowed
|
||||
updated = true
|
||||
}
|
||||
} else if config.ServerRDPAllowed == nil {
|
||||
config.ServerRDPAllowed = util.False()
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
@@ -777,8 +794,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
if err = client.HealthCheck(); err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -17,12 +17,10 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
type mockMgmProber struct {
|
||||
key wgtypes.Key
|
||||
}
|
||||
type mockMgmProber struct{}
|
||||
|
||||
func (m *mockMgmProber) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
return &m.key, nil
|
||||
func (m *mockMgmProber) HealthCheck() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockMgmProber) Close() error { return nil }
|
||||
@@ -247,11 +245,7 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||
func TestUpdateOldManagementURL(t *testing.T) {
|
||||
origProber := newMgmProber
|
||||
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
|
||||
key, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mockMgmProber{key: key.PublicKey()}, nil
|
||||
return &mockMgmProber{}, nil
|
||||
}
|
||||
t.Cleanup(func() { newMgmProber = origProber })
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ type Manager interface {
|
||||
TriggerSelection(route.HAMap)
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
GetClientRoutes() route.HAMap
|
||||
GetSelectedClientRoutes() route.HAMap
|
||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
@@ -465,6 +466,16 @@ func (m *DefaultManager) GetClientRoutes() route.HAMap {
|
||||
return maps.Clone(m.clientRoutes)
|
||||
}
|
||||
|
||||
// GetSelectedClientRoutes returns only the currently selected/active client routes,
|
||||
// filtering out deselected exit nodes. Use this instead of GetClientRoutes when checking
|
||||
// if traffic should be routed through the tunnel.
|
||||
func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
m.mux.Lock()
|
||||
|
||||
@@ -18,6 +18,7 @@ type MockManager struct {
|
||||
TriggerSelectionFunc func(haMap route.HAMap)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
GetClientRoutesFunc func() route.HAMap
|
||||
GetSelectedClientRoutesFunc func() route.HAMap
|
||||
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
||||
StopFunc func(manager *statemanager.Manager)
|
||||
}
|
||||
@@ -61,7 +62,7 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
|
||||
// GetClientRoutes mock implementation of GetClientRoutes from the Manager interface
|
||||
func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||
if m.GetClientRoutesFunc != nil {
|
||||
return m.GetClientRoutesFunc()
|
||||
@@ -69,6 +70,14 @@ func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSelectedClientRoutes mock implementation of GetSelectedClientRoutes from the Manager interface
|
||||
func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
|
||||
if m.GetSelectedClientRoutesFunc != nil {
|
||||
return m.GetSelectedClientRoutesFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
||||
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
if m.GetClientRoutesWithNetIDFunc != nil {
|
||||
|
||||
@@ -53,7 +53,6 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
n.currentPrefixes = newNets
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
|
||||
@@ -161,7 +161,11 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||
hostDNS := []netip.AddrPort{
|
||||
netip.MustParseAddrPort("9.9.9.9:53"),
|
||||
netip.MustParseAddrPort("149.112.112.112:53"),
|
||||
}
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
@@ -472,6 +472,7 @@ type LoginRequest struct {
|
||||
EnableSSHRemotePortForwarding *bool `protobuf:"varint,37,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
|
||||
DisableSSHAuth *bool `protobuf:"varint,38,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
|
||||
SshJWTCacheTTL *int32 `protobuf:"varint,39,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
|
||||
ServerRDPAllowed *bool `protobuf:"varint,40,opt,name=serverRDPAllowed,proto3,oneof" json:"serverRDPAllowed,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -780,6 +781,13 @@ func (x *LoginRequest) GetSshJWTCacheTTL() int32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *LoginRequest) GetServerRDPAllowed() bool {
|
||||
if x != nil && x.ServerRDPAllowed != nil {
|
||||
return *x.ServerRDPAllowed
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type LoginResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
|
||||
@@ -1312,6 +1320,7 @@ type GetConfigResponse struct {
|
||||
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
|
||||
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
|
||||
SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"`
|
||||
ServerRDPAllowed bool `protobuf:"varint,27,opt,name=serverRDPAllowed,proto3" json:"serverRDPAllowed,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -1528,6 +1537,13 @@ func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *GetConfigResponse) GetServerRDPAllowed() bool {
|
||||
if x != nil {
|
||||
return x.ServerRDPAllowed
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
type PeerState struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
@@ -4139,6 +4155,7 @@ type SetConfigRequest struct {
|
||||
EnableSSHRemotePortForwarding *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
|
||||
DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
|
||||
SshJWTCacheTTL *int32 `protobuf:"varint,34,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
|
||||
ServerRDPAllowed *bool `protobuf:"varint,35,opt,name=serverRDPAllowed,proto3,oneof" json:"serverRDPAllowed,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -4411,6 +4428,13 @@ func (x *SetConfigRequest) GetSshJWTCacheTTL() int32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *SetConfigRequest) GetServerRDPAllowed() bool {
|
||||
if x != nil && x.ServerRDPAllowed != nil {
|
||||
return *x.ServerRDPAllowed
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type SetConfigResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
@@ -209,6 +209,8 @@ message LoginRequest {
|
||||
optional bool enableSSHRemotePortForwarding = 37;
|
||||
optional bool disableSSHAuth = 38;
|
||||
optional int32 sshJWTCacheTTL = 39;
|
||||
|
||||
optional bool serverRDPAllowed = 40;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -316,6 +318,8 @@ message GetConfigResponse {
|
||||
bool disableSSHAuth = 25;
|
||||
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
|
||||
bool serverRDPAllowed = 27;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -677,6 +681,8 @@ message SetConfigRequest {
|
||||
optional bool enableSSHRemotePortForwarding = 32;
|
||||
optional bool disableSSHAuth = 33;
|
||||
optional int32 sshJWTCacheTTL = 34;
|
||||
|
||||
optional bool serverRDPAllowed = 35;
|
||||
}
|
||||
|
||||
message SetConfigResponse{}
|
||||
|
||||
88
client/rdp/client/client.go
Normal file
88
client/rdp/client/client.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultTimeout is the default timeout for sideband auth requests.
|
||||
DefaultTimeout = 30 * time.Second
|
||||
|
||||
// maxResponseSize is the maximum size of an auth response in bytes.
|
||||
maxResponseSize = 64 * 1024
|
||||
)
|
||||
|
||||
// Client connects to a target peer's RDP sideband auth server to request access.
|
||||
type Client struct {
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// New creates a new sideband RDP auth client.
|
||||
func New() *Client {
|
||||
return &Client{
|
||||
Timeout: DefaultTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// RequestAuth sends an authorization request to the target peer's sideband server
|
||||
// and returns the response. The addr should be in "host:port" format.
|
||||
func (c *Client) RequestAuth(ctx context.Context, addr string, req *rdpserver.AuthRequest) (*rdpserver.AuthResponse, error) {
|
||||
timeout := c.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = DefaultTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to RDP auth server at %s: %w", addr, err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
deadline, ok := ctx.Deadline()
|
||||
if ok {
|
||||
if err := conn.SetDeadline(deadline); err != nil {
|
||||
return nil, fmt.Errorf("set connection deadline: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Send request
|
||||
reqData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal auth request: %w", err)
|
||||
}
|
||||
|
||||
if _, err := conn.Write(reqData); err != nil {
|
||||
return nil, fmt.Errorf("send auth request: %w", err)
|
||||
}
|
||||
|
||||
// Signal we're done writing so the server can read the full request
|
||||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||
if err := tcpConn.CloseWrite(); err != nil {
|
||||
return nil, fmt.Errorf("close write: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Read response
|
||||
respData, err := io.ReadAll(io.LimitReader(conn, maxResponseSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read auth response: %w", err)
|
||||
}
|
||||
|
||||
var resp rdpserver.AuthResponse
|
||||
if err := json.Unmarshal(respData, &resp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal auth response: %w", err)
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
31
client/rdp/credprov/Cargo.toml
Normal file
31
client/rdp/credprov/Cargo.toml
Normal file
@@ -0,0 +1,31 @@
|
||||
[package]
|
||||
name = "netbird-credprov"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
description = "NetBird RDP Credential Provider for Windows"
|
||||
license = "BSD-3-Clause"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
windows = { version = "0.58", features = [
|
||||
"implement",
|
||||
"Win32_Foundation",
|
||||
"Win32_System_Com",
|
||||
"Win32_UI_Shell",
|
||||
"Win32_Security",
|
||||
"Win32_Security_Authentication_Identity",
|
||||
"Win32_Security_Credentials",
|
||||
"Win32_System_RemoteDesktop",
|
||||
"Win32_System_Threading",
|
||||
] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
uuid = { version = "1", features = ["v4"] }
|
||||
log = "0.4"
|
||||
|
||||
[profile.release]
|
||||
opt-level = "s"
|
||||
lto = true
|
||||
strip = true
|
||||
210
client/rdp/credprov/src/credential.rs
Normal file
210
client/rdp/credprov/src/credential.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
//! ICredentialProviderCredential implementation.
|
||||
//!
|
||||
//! Represents a single "NetBird Login" credential tile on the Windows login screen.
|
||||
//! When selected, it queries the local NetBird agent for pending RDP sessions and
|
||||
//! performs S4U logon to authenticate the user without a password.
|
||||
|
||||
use crate::named_pipe_client::{NamedPipeClient, PipeResponse};
|
||||
use crate::s4u;
|
||||
use std::sync::Mutex;
|
||||
use windows::core::*;
|
||||
use windows::Win32::Foundation::*;
|
||||
use windows::Win32::Security::Credentials::*;
|
||||
use windows::Win32::UI::Shell::*;
|
||||
|
||||
/// NetBird credential tile that appears on the Windows login screen.
|
||||
#[implement(ICredentialProviderCredential)]
|
||||
pub struct NetBirdCredential {
|
||||
/// The pending session information from the NetBird agent.
|
||||
session: Mutex<Option<PipeResponse>>,
|
||||
/// The remote IP address of the connecting peer.
|
||||
remote_ip: Mutex<String>,
|
||||
}
|
||||
|
||||
impl NetBirdCredential {
|
||||
pub fn new(remote_ip: String, session: PipeResponse) -> Self {
|
||||
Self {
|
||||
session: Mutex::new(Some(session)),
|
||||
remote_ip: Mutex::new(remote_ip),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ICredentialProviderCredential_Impl for NetBirdCredential_Impl {
|
||||
fn Advise(&self, _pcpce: Option<&ICredentialProviderCredentialEvents>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn UnAdvise(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn SetSelected(&self, _pbautologon: *mut BOOL) -> Result<()> {
|
||||
// Auto-logon when this credential is selected
|
||||
unsafe {
|
||||
if !_pbautologon.is_null() {
|
||||
*_pbautologon = TRUE;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn SetDeselected(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn GetFieldState(
|
||||
&self,
|
||||
_dwfieldid: u32,
|
||||
_pcpfs: *mut CREDENTIAL_PROVIDER_FIELD_STATE,
|
||||
_pcpfis: *mut CREDENTIAL_PROVIDER_FIELD_INTERACTIVE_STATE,
|
||||
) -> Result<()> {
|
||||
// We have a single display-only field showing "NetBird Login"
|
||||
unsafe {
|
||||
if !_pcpfs.is_null() {
|
||||
*_pcpfs = CPFS_DISPLAY_IN_SELECTED_TILE;
|
||||
}
|
||||
if !_pcpfis.is_null() {
|
||||
*_pcpfis = CPFIS_NONE;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn GetStringValue(&self, _dwfieldid: u32) -> Result<PWSTR> {
|
||||
let session = self.session.lock().unwrap();
|
||||
let text = if let Some(ref s) = *session {
|
||||
format!("NetBird: Logging in as {}", s.os_user)
|
||||
} else {
|
||||
"NetBird Login".to_string()
|
||||
};
|
||||
|
||||
let wide: Vec<u16> = text.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let ptr = unsafe {
|
||||
let mem = windows::Win32::System::Com::CoTaskMemAlloc(wide.len() * 2) as *mut u16;
|
||||
if mem.is_null() {
|
||||
return Err(E_OUTOFMEMORY.into());
|
||||
}
|
||||
std::ptr::copy_nonoverlapping(wide.as_ptr(), mem, wide.len());
|
||||
PWSTR(mem)
|
||||
};
|
||||
|
||||
Ok(ptr)
|
||||
}
|
||||
|
||||
fn GetBitmapValue(&self, _dwfieldid: u32) -> Result<HBITMAP> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn GetCheckboxValue(&self, _dwfieldid: u32, _pbchecked: *mut BOOL, _ppszlabel: *mut PWSTR) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn GetSubmitButtonValue(&self, _dwfieldid: u32, _pdwadjacentto: *mut u32) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn GetComboBoxValueCount(&self, _dwfieldid: u32, _pcitems: *mut u32, _pdwselecteditem: *mut u32) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn GetComboBoxValueAt(&self, _dwfieldid: u32, _dwitem: u32) -> Result<PWSTR> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn SetStringValue(&self, _dwfieldid: u32, _psz: &PCWSTR) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn SetCheckboxValue(&self, _dwfieldid: u32, _bchecked: BOOL) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn SetComboBoxSelectedValue(&self, _dwfieldid: u32, _dwselecteditem: u32) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn CommandLinkClicked(&self, _dwfieldid: u32) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn GetSerialization(
|
||||
&self,
|
||||
_pcpgsr: *mut CREDENTIAL_PROVIDER_GET_SERIALIZATION_RESPONSE,
|
||||
_pcpcs: *mut CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION,
|
||||
_ppszoptionalstatustext: *mut PWSTR,
|
||||
_pcpsioptionalstatusicon: *mut CREDENTIAL_PROVIDER_STATUS_ICON,
|
||||
) -> Result<()> {
|
||||
let session = self.session.lock().unwrap();
|
||||
let session_info = match &*session {
|
||||
Some(s) => s.clone(),
|
||||
None => {
|
||||
unsafe {
|
||||
*_pcpgsr = CPGSR_NO_CREDENTIAL_NOT_FINISHED;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// Consume the session with the agent
|
||||
if let Err(e) = NamedPipeClient::consume_session(&session_info.session_id) {
|
||||
log::error!("Failed to consume RDP session: {}", e);
|
||||
unsafe {
|
||||
*_pcpgsr = CPGSR_NO_CREDENTIAL_FINISHED;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Perform S4U logon
|
||||
let username = &session_info.os_user;
|
||||
let domain = if session_info.domain.is_empty() {
|
||||
"."
|
||||
} else {
|
||||
&session_info.domain
|
||||
};
|
||||
|
||||
match s4u::generate_s4u_token(username, domain) {
|
||||
Ok(_token) => {
|
||||
// In a full implementation, we would serialize the token into
|
||||
// CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION format
|
||||
// (KerbInteractiveLogon or MsV1_0InteractiveLogon structure).
|
||||
//
|
||||
// For the POC, we signal success. The actual serialization requires
|
||||
// building the proper KERB_INTERACTIVE_LOGON or MSV1_0_INTERACTIVE_LOGON
|
||||
// structure with the token handle, which is complex.
|
||||
//
|
||||
// TODO: Build proper credential serialization from S4U token
|
||||
log::info!(
|
||||
"S4U logon successful for {}\\{}, session {}",
|
||||
domain,
|
||||
username,
|
||||
session_info.session_id
|
||||
);
|
||||
|
||||
unsafe {
|
||||
*_pcpgsr = CPGSR_RETURN_CREDENTIAL_FINISHED;
|
||||
// Note: In production, pcpcs would be filled with the serialized credentials
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("S4U logon failed for {}\\{}: {}", domain, username, e);
|
||||
unsafe {
|
||||
*_pcpgsr = CPGSR_NO_CREDENTIAL_FINISHED;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn ReportResult(
|
||||
&self,
|
||||
_ntstatus: NTSTATUS,
|
||||
_ntssubstatus: NTSTATUS,
|
||||
_ppszoptionalstatustext: *mut PWSTR,
|
||||
_pcpsioptionalstatusicon: *mut CREDENTIAL_PROVIDER_STATUS_ICON,
|
||||
) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
11
client/rdp/credprov/src/guid.rs
Normal file
11
client/rdp/credprov/src/guid.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
use windows::core::GUID;
|
||||
|
||||
/// CLSID for the NetBird RDP Credential Provider.
|
||||
/// Generated UUID: {7B3A8E5F-1C4D-4F8A-B2E6-9D0F3A7C5E1B}
|
||||
pub const CLSID_NETBIRD_CREDENTIAL_PROVIDER: GUID = GUID::from_u128(
|
||||
0x7B3A8E5F_1C4D_4F8A_B2E6_9D0F3A7C5E1B,
|
||||
);
|
||||
|
||||
/// Registry path for credential providers.
|
||||
pub const CREDENTIAL_PROVIDER_REGISTRY_PATH: &str =
|
||||
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers";
|
||||
309
client/rdp/credprov/src/lib.rs
Normal file
309
client/rdp/credprov/src/lib.rs
Normal file
@@ -0,0 +1,309 @@
|
||||
//! NetBird RDP Credential Provider for Windows.
|
||||
//!
|
||||
//! This DLL is a Windows Credential Provider that enables passwordless RDP access
|
||||
//! to machines running the NetBird agent. It is loaded by Windows' LogonUI.exe
|
||||
//! via COM when the login screen is displayed.
|
||||
//!
|
||||
//! ## How it works
|
||||
//!
|
||||
//! 1. The DLL is registered as a Credential Provider in the Windows registry
|
||||
//! 2. When an RDP session begins, LogonUI loads the DLL
|
||||
//! 3. The DLL queries the local NetBird agent via named pipe for pending sessions
|
||||
//! 4. If a pending session exists for the connecting peer, the DLL:
|
||||
//! - Shows a "NetBird Login" credential tile
|
||||
//! - Performs S4U logon to create a Windows token without a password
|
||||
//! - Returns the token to LogonUI for session creation
|
||||
|
||||
mod credential;
|
||||
mod guid;
|
||||
mod named_pipe_client;
|
||||
mod provider;
|
||||
mod s4u;
|
||||
|
||||
use guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
|
||||
use provider::NetBirdCredentialProvider;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use windows::core::*;
|
||||
use windows::Win32::Foundation::*;
|
||||
use windows::Win32::System::Com::*;
|
||||
|
||||
/// DLL reference count for COM lifecycle management.
|
||||
static DLL_REF_COUNT: AtomicU32 = AtomicU32::new(0);
|
||||
|
||||
/// DLL module handle.
|
||||
static mut DLL_MODULE: HMODULE = HMODULE(std::ptr::null_mut());
|
||||
|
||||
/// COM class factory for creating NetBirdCredentialProvider instances.
|
||||
#[implement(IClassFactory)]
|
||||
struct NetBirdClassFactory;
|
||||
|
||||
impl IClassFactory_Impl for NetBirdClassFactory_Impl {
|
||||
fn CreateInstance(
|
||||
&self,
|
||||
_punkouter: Option<&IUnknown>,
|
||||
riid: *const GUID,
|
||||
ppvobject: *mut *mut std::ffi::c_void,
|
||||
) -> Result<()> {
|
||||
unsafe {
|
||||
if !ppvobject.is_null() {
|
||||
*ppvobject = std::ptr::null_mut();
|
||||
}
|
||||
}
|
||||
|
||||
if _punkouter.is_some() {
|
||||
return Err(CLASS_E_NOAGGREGATION.into());
|
||||
}
|
||||
|
||||
let provider = NetBirdCredentialProvider::new();
|
||||
let unknown: IUnknown = provider.into();
|
||||
|
||||
unsafe {
|
||||
unknown.query(riid, ppvobject).ok()
|
||||
}
|
||||
}
|
||||
|
||||
fn LockServer(&self, flock: BOOL) -> Result<()> {
|
||||
if flock.as_bool() {
|
||||
DLL_REF_COUNT.fetch_add(1, Ordering::SeqCst);
|
||||
} else {
|
||||
DLL_REF_COUNT.fetch_sub(1, Ordering::SeqCst);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// DLL entry point.
|
||||
#[no_mangle]
|
||||
extern "system" fn DllMain(hinstance: HMODULE, reason: u32, _reserved: *mut std::ffi::c_void) -> BOOL {
|
||||
const DLL_PROCESS_ATTACH: u32 = 1;
|
||||
|
||||
if reason == DLL_PROCESS_ATTACH {
|
||||
unsafe {
|
||||
DLL_MODULE = hinstance;
|
||||
}
|
||||
}
|
||||
|
||||
TRUE
|
||||
}
|
||||
|
||||
/// COM entry point: returns a class factory for the requested CLSID.
|
||||
#[no_mangle]
|
||||
extern "system" fn DllGetClassObject(
|
||||
rclsid: *const GUID,
|
||||
riid: *const GUID,
|
||||
ppv: *mut *mut std::ffi::c_void,
|
||||
) -> HRESULT {
|
||||
unsafe {
|
||||
if ppv.is_null() {
|
||||
return E_POINTER;
|
||||
}
|
||||
*ppv = std::ptr::null_mut();
|
||||
|
||||
if *rclsid != CLSID_NETBIRD_CREDENTIAL_PROVIDER {
|
||||
return CLASS_E_CLASSNOTAVAILABLE;
|
||||
}
|
||||
|
||||
let factory = NetBirdClassFactory;
|
||||
let unknown: IUnknown = factory.into();
|
||||
|
||||
match unknown.query(riid, ppv) {
|
||||
Ok(()) => S_OK,
|
||||
Err(e) => e.code(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// COM entry point: indicates whether the DLL can be unloaded.
|
||||
#[no_mangle]
|
||||
extern "system" fn DllCanUnloadNow() -> HRESULT {
|
||||
if DLL_REF_COUNT.load(Ordering::SeqCst) == 0 {
|
||||
S_OK
|
||||
} else {
|
||||
S_FALSE
|
||||
}
|
||||
}
|
||||
|
||||
/// Self-registration: called by regsvr32 to register the credential provider.
|
||||
#[no_mangle]
|
||||
extern "system" fn DllRegisterServer() -> HRESULT {
|
||||
match register_credential_provider(true) {
|
||||
Ok(()) => S_OK,
|
||||
Err(_) => E_FAIL,
|
||||
}
|
||||
}
|
||||
|
||||
/// Self-unregistration: called by regsvr32 /u to unregister the credential provider.
|
||||
#[no_mangle]
|
||||
extern "system" fn DllUnregisterServer() -> HRESULT {
|
||||
match register_credential_provider(false) {
|
||||
Ok(()) => S_OK,
|
||||
Err(_) => E_FAIL,
|
||||
}
|
||||
}
|
||||
|
||||
fn register_credential_provider(register: bool) -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||
use windows::Win32::System::Registry::*;
|
||||
|
||||
let clsid_str = format!("{{{:08X}-{:04X}-{:04X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}",
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data1,
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data2,
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data3,
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[0],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[1],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[2],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[3],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[4],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[5],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[6],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[7],
|
||||
);
|
||||
|
||||
if register {
|
||||
// Register under Credential Providers
|
||||
let cp_key_path = format!(
|
||||
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
|
||||
clsid_str
|
||||
);
|
||||
|
||||
let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let mut hkey = HKEY::default();
|
||||
|
||||
unsafe {
|
||||
let result = RegCreateKeyExW(
|
||||
HKEY_LOCAL_MACHINE,
|
||||
PCWSTR(cp_key_wide.as_ptr()),
|
||||
0,
|
||||
PCWSTR::null(),
|
||||
REG_OPTION_NON_VOLATILE,
|
||||
KEY_WRITE,
|
||||
None,
|
||||
&mut hkey,
|
||||
None,
|
||||
);
|
||||
if result.is_err() {
|
||||
return Err("Failed to create credential provider registry key".into());
|
||||
}
|
||||
|
||||
let value: Vec<u16> = "NetBird RDP Credential Provider"
|
||||
.encode_utf16()
|
||||
.chain(std::iter::once(0))
|
||||
.collect();
|
||||
let _ = RegSetValueExW(
|
||||
hkey,
|
||||
PCWSTR::null(),
|
||||
0,
|
||||
REG_SZ,
|
||||
Some(std::slice::from_raw_parts(
|
||||
value.as_ptr() as *const u8,
|
||||
value.len() * 2,
|
||||
)),
|
||||
);
|
||||
let _ = RegCloseKey(hkey);
|
||||
}
|
||||
|
||||
// Register CLSID in CLSID hive
|
||||
let clsid_key_path = format!(r"CLSID\{}", clsid_str);
|
||||
let clsid_key_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
|
||||
unsafe {
|
||||
let result = RegCreateKeyExW(
|
||||
HKEY_CLASSES_ROOT,
|
||||
PCWSTR(clsid_key_wide.as_ptr()),
|
||||
0,
|
||||
PCWSTR::null(),
|
||||
REG_OPTION_NON_VOLATILE,
|
||||
KEY_WRITE,
|
||||
None,
|
||||
&mut hkey,
|
||||
None,
|
||||
);
|
||||
if result.is_err() {
|
||||
return Err("Failed to create CLSID registry key".into());
|
||||
}
|
||||
let _ = RegCloseKey(hkey);
|
||||
|
||||
// InprocServer32 subkey
|
||||
let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
|
||||
let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
|
||||
let result = RegCreateKeyExW(
|
||||
HKEY_CLASSES_ROOT,
|
||||
PCWSTR(inproc_wide.as_ptr()),
|
||||
0,
|
||||
PCWSTR::null(),
|
||||
REG_OPTION_NON_VOLATILE,
|
||||
KEY_WRITE,
|
||||
None,
|
||||
&mut hkey,
|
||||
None,
|
||||
);
|
||||
if result.is_err() {
|
||||
return Err("Failed to create InprocServer32 registry key".into());
|
||||
}
|
||||
|
||||
// Set DLL path
|
||||
let mut dll_path = [0u16; 260];
|
||||
let len = windows::Win32::System::LibraryLoader::GetModuleFileNameW(
|
||||
DLL_MODULE,
|
||||
&mut dll_path,
|
||||
);
|
||||
if len > 0 {
|
||||
let _ = RegSetValueExW(
|
||||
hkey,
|
||||
PCWSTR::null(),
|
||||
0,
|
||||
REG_SZ,
|
||||
Some(std::slice::from_raw_parts(
|
||||
dll_path.as_ptr() as *const u8,
|
||||
(len as usize + 1) * 2,
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
// Set threading model
|
||||
let threading: Vec<u16> = "Apartment"
|
||||
.encode_utf16()
|
||||
.chain(std::iter::once(0))
|
||||
.collect();
|
||||
let threading_name: Vec<u16> = "ThreadingModel"
|
||||
.encode_utf16()
|
||||
.chain(std::iter::once(0))
|
||||
.collect();
|
||||
let _ = RegSetValueExW(
|
||||
hkey,
|
||||
PCWSTR(threading_name.as_ptr()),
|
||||
0,
|
||||
REG_SZ,
|
||||
Some(std::slice::from_raw_parts(
|
||||
threading.as_ptr() as *const u8,
|
||||
threading.len() * 2,
|
||||
)),
|
||||
);
|
||||
|
||||
let _ = RegCloseKey(hkey);
|
||||
}
|
||||
} else {
|
||||
// Unregister
|
||||
let cp_key_path = format!(
|
||||
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
|
||||
clsid_str
|
||||
);
|
||||
let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
|
||||
unsafe {
|
||||
let _ = RegDeleteKeyW(HKEY_LOCAL_MACHINE, PCWSTR(cp_key_wide.as_ptr()));
|
||||
}
|
||||
|
||||
let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
|
||||
let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let clsid_key_path = format!(r"CLSID\{}", clsid_str);
|
||||
let clsid_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
|
||||
unsafe {
|
||||
let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(inproc_wide.as_ptr()));
|
||||
let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(clsid_wide.as_ptr()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
135
client/rdp/credprov/src/named_pipe_client.rs
Normal file
135
client/rdp/credprov/src/named_pipe_client.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::{Read, Write};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Named pipe path for communicating with the NetBird agent.
|
||||
const PIPE_NAME: &str = r"\\.\pipe\netbird-rdp-auth";
|
||||
|
||||
/// Maximum response size from the agent.
|
||||
const MAX_RESPONSE_SIZE: usize = 4096;
|
||||
|
||||
/// Timeout for named pipe operations.
|
||||
const PIPE_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// Request sent to the NetBird agent via named pipe.
|
||||
#[derive(Serialize)]
|
||||
pub struct PipeRequest {
|
||||
pub action: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub remote_ip: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Response received from the NetBird agent via named pipe.
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct PipeResponse {
|
||||
pub found: bool,
|
||||
#[serde(default)]
|
||||
pub session_id: String,
|
||||
#[serde(default)]
|
||||
pub os_user: String,
|
||||
#[serde(default)]
|
||||
pub domain: String,
|
||||
}
|
||||
|
||||
/// Client for communicating with the NetBird agent's named pipe server.
|
||||
pub struct NamedPipeClient;
|
||||
|
||||
impl NamedPipeClient {
|
||||
/// Query the NetBird agent for a pending RDP session matching the given remote IP.
|
||||
pub fn query_pending(remote_ip: &str) -> Result<PipeResponse, PipeError> {
|
||||
let request = PipeRequest {
|
||||
action: "query_pending".to_string(),
|
||||
remote_ip: Some(remote_ip.to_string()),
|
||||
session_id: None,
|
||||
};
|
||||
Self::send_request(&request)
|
||||
}
|
||||
|
||||
/// Tell the NetBird agent to consume (mark as used) a pending session.
|
||||
pub fn consume_session(session_id: &str) -> Result<PipeResponse, PipeError> {
|
||||
let request = PipeRequest {
|
||||
action: "consume".to_string(),
|
||||
remote_ip: None,
|
||||
session_id: Some(session_id.to_string()),
|
||||
};
|
||||
Self::send_request(&request)
|
||||
}
|
||||
|
||||
fn send_request(request: &PipeRequest) -> Result<PipeResponse, PipeError> {
|
||||
let request_data =
|
||||
serde_json::to_vec(request).map_err(|e| PipeError::Serialization(e.to_string()))?;
|
||||
|
||||
// Open named pipe (CreateFile in Windows)
|
||||
let mut pipe = Self::open_pipe()?;
|
||||
|
||||
// Write request
|
||||
pipe.write_all(&request_data)
|
||||
.map_err(|e| PipeError::Write(e.to_string()))?;
|
||||
|
||||
// Shutdown write side to signal end of request
|
||||
// For named pipes on Windows, we rely on the message boundary
|
||||
pipe.flush()
|
||||
.map_err(|e| PipeError::Write(e.to_string()))?;
|
||||
|
||||
// Read response
|
||||
let mut response_data = vec![0u8; MAX_RESPONSE_SIZE];
|
||||
let n = pipe
|
||||
.read(&mut response_data)
|
||||
.map_err(|e| PipeError::Read(e.to_string()))?;
|
||||
|
||||
let response: PipeResponse = serde_json::from_slice(&response_data[..n])
|
||||
.map_err(|e| PipeError::Deserialization(e.to_string()))?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
fn open_pipe() -> Result<std::fs::File, PipeError> {
|
||||
// On Windows, named pipes are opened like files
|
||||
use std::fs::OpenOptions;
|
||||
|
||||
// Try to open the pipe with a brief retry for PIPE_BUSY
|
||||
for attempt in 0..3 {
|
||||
match OpenOptions::new().read(true).write(true).open(PIPE_NAME) {
|
||||
Ok(file) => return Ok(file),
|
||||
Err(e) => {
|
||||
if attempt < 2 {
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
continue;
|
||||
}
|
||||
return Err(PipeError::Connect(format!(
|
||||
"failed to open pipe {}: {}",
|
||||
PIPE_NAME, e
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(PipeError::Connect("exhausted pipe connection attempts".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors that can occur during named pipe communication.
|
||||
#[derive(Debug)]
|
||||
pub enum PipeError {
|
||||
Connect(String),
|
||||
Write(String),
|
||||
Read(String),
|
||||
Serialization(String),
|
||||
Deserialization(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PipeError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
PipeError::Connect(e) => write!(f, "pipe connect: {}", e),
|
||||
PipeError::Write(e) => write!(f, "pipe write: {}", e),
|
||||
PipeError::Read(e) => write!(f, "pipe read: {}", e),
|
||||
PipeError::Serialization(e) => write!(f, "pipe serialization: {}", e),
|
||||
PipeError::Deserialization(e) => write!(f, "pipe deserialization: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for PipeError {}
|
||||
270
client/rdp/credprov/src/provider.rs
Normal file
270
client/rdp/credprov/src/provider.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
//! ICredentialProvider implementation.
|
||||
//!
|
||||
//! This is the main COM object that Windows' LogonUI.exe instantiates.
|
||||
//! It determines whether to show a "NetBird Login" credential tile based on
|
||||
//! whether the NetBird agent has a pending RDP session for the connecting peer.
|
||||
|
||||
use crate::credential::NetBirdCredential;
|
||||
use crate::guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
|
||||
use crate::named_pipe_client::NamedPipeClient;
|
||||
use std::sync::Mutex;
|
||||
use windows::core::*;
|
||||
use windows::Win32::Foundation::*;
|
||||
use windows::Win32::Security::Credentials::*;
|
||||
use windows::Win32::System::RemoteDesktop::*;
|
||||
|
||||
/// The NetBird Credential Provider, loaded by LogonUI.exe via COM.
|
||||
#[implement(ICredentialProvider)]
|
||||
pub struct NetBirdCredentialProvider {
|
||||
/// The credential tile (if a pending session was found).
|
||||
credential: Mutex<Option<ICredentialProviderCredential>>,
|
||||
/// Whether this provider is active for the current usage scenario.
|
||||
active: Mutex<bool>,
|
||||
}
|
||||
|
||||
impl NetBirdCredentialProvider {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
credential: Mutex::new(None),
|
||||
active: Mutex::new(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ICredentialProvider_Impl for NetBirdCredentialProvider_Impl {
|
||||
fn SetUsageScenario(
|
||||
&self,
|
||||
cpus: CREDENTIAL_PROVIDER_USAGE_SCENARIO,
|
||||
_dwflags: u32,
|
||||
) -> Result<()> {
|
||||
let mut active = self.active.lock().unwrap();
|
||||
|
||||
match cpus {
|
||||
CPUS_LOGON | CPUS_UNLOCK_WORKSTATION => {
|
||||
// We activate for RDP logon and unlock scenarios
|
||||
*active = true;
|
||||
log::info!("NetBird CP activated for usage scenario {:?}", cpus.0);
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
// Don't activate for credui or other scenarios
|
||||
*active = false;
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn SetSerialization(
|
||||
&self,
|
||||
_pcpcs: *const CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION,
|
||||
) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn Advise(
|
||||
&self,
|
||||
_pcpe: Option<&ICredentialProviderEvents>,
|
||||
_upadvisecontext: usize,
|
||||
) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn UnAdvise(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn GetFieldDescriptorCount(&self) -> Result<u32> {
|
||||
// We have one field: a large text label showing "NetBird: Logging in as <user>"
|
||||
Ok(1)
|
||||
}
|
||||
|
||||
fn GetFieldDescriptorAt(
|
||||
&self,
|
||||
_dwindex: u32,
|
||||
_ppcpfd: *mut *mut CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR,
|
||||
) -> Result<()> {
|
||||
if _dwindex != 0 {
|
||||
return Err(E_INVALIDARG.into());
|
||||
}
|
||||
|
||||
let label = "NetBird Login";
|
||||
let wide: Vec<u16> = label.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
|
||||
unsafe {
|
||||
let desc = windows::Win32::System::Com::CoTaskMemAlloc(
|
||||
std::mem::size_of::<CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR>(),
|
||||
) as *mut CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR;
|
||||
|
||||
if desc.is_null() {
|
||||
return Err(E_OUTOFMEMORY.into());
|
||||
}
|
||||
|
||||
let label_mem =
|
||||
windows::Win32::System::Com::CoTaskMemAlloc(wide.len() * 2) as *mut u16;
|
||||
if label_mem.is_null() {
|
||||
windows::Win32::System::Com::CoTaskMemFree(Some(desc as *const _));
|
||||
return Err(E_OUTOFMEMORY.into());
|
||||
}
|
||||
std::ptr::copy_nonoverlapping(wide.as_ptr(), label_mem, wide.len());
|
||||
|
||||
(*desc).dwFieldID = 0;
|
||||
(*desc).cpft = CPFT_LARGE_TEXT;
|
||||
(*desc).pszLabel = PWSTR(label_mem);
|
||||
(*desc).guidFieldType = GUID::zeroed();
|
||||
|
||||
*_ppcpfd = desc;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn GetCredentialCount(
|
||||
&self,
|
||||
_pdwcount: *mut u32,
|
||||
_pdwdefault: *mut u32,
|
||||
_pbautologinwithdefault: *mut BOOL,
|
||||
) -> Result<()> {
|
||||
let active = self.active.lock().unwrap();
|
||||
if !*active {
|
||||
unsafe {
|
||||
*_pdwcount = 0;
|
||||
*_pdwdefault = u32::MAX;
|
||||
*_pbautologinwithdefault = FALSE;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Try to get the client IP of the current RDP session
|
||||
let remote_ip = match get_rdp_client_ip() {
|
||||
Some(ip) => ip,
|
||||
None => {
|
||||
log::debug!("NetBird CP: could not determine RDP client IP");
|
||||
unsafe {
|
||||
*_pdwcount = 0;
|
||||
*_pdwdefault = u32::MAX;
|
||||
*_pbautologinwithdefault = FALSE;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// Query the NetBird agent for a pending session
|
||||
match NamedPipeClient::query_pending(&remote_ip) {
|
||||
Ok(response) if response.found => {
|
||||
log::info!(
|
||||
"NetBird CP: found pending session for {} -> {}",
|
||||
remote_ip,
|
||||
response.os_user
|
||||
);
|
||||
|
||||
let cred = NetBirdCredential::new(remote_ip, response);
|
||||
let icred: ICredentialProviderCredential = cred.into();
|
||||
|
||||
let mut credential = self.credential.lock().unwrap();
|
||||
*credential = Some(icred);
|
||||
|
||||
unsafe {
|
||||
*_pdwcount = 1;
|
||||
*_pdwdefault = 0;
|
||||
*_pbautologinwithdefault = TRUE; // auto-logon
|
||||
}
|
||||
}
|
||||
Ok(_) => {
|
||||
log::debug!("NetBird CP: no pending session for {}", remote_ip);
|
||||
unsafe {
|
||||
*_pdwcount = 0;
|
||||
*_pdwdefault = u32::MAX;
|
||||
*_pbautologinwithdefault = FALSE;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::debug!("NetBird CP: pipe query failed: {}", e);
|
||||
unsafe {
|
||||
*_pdwcount = 0;
|
||||
*_pdwdefault = u32::MAX;
|
||||
*_pbautologinwithdefault = FALSE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn GetCredentialAt(
|
||||
&self,
|
||||
_dwindex: u32,
|
||||
_ppcpc: *mut Option<ICredentialProviderCredential>,
|
||||
) -> Result<()> {
|
||||
if _dwindex != 0 {
|
||||
return Err(E_INVALIDARG.into());
|
||||
}
|
||||
|
||||
let credential = self.credential.lock().unwrap();
|
||||
match &*credential {
|
||||
Some(cred) => {
|
||||
unsafe {
|
||||
*_ppcpc = Some(cred.clone());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
None => Err(E_UNEXPECTED.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the IP address of the remote RDP client for the current session.
|
||||
fn get_rdp_client_ip() -> Option<String> {
|
||||
unsafe {
|
||||
// Get the current session ID
|
||||
let process_id = windows::Win32::System::Threading::GetCurrentProcessId();
|
||||
let mut session_id = 0u32;
|
||||
|
||||
if !windows::Win32::System::RemoteDesktop::ProcessIdToSessionId(process_id, &mut session_id)
|
||||
.as_bool()
|
||||
{
|
||||
log::debug!("ProcessIdToSessionId failed");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Query the client address
|
||||
let mut buffer: *mut WTS_CLIENT_ADDRESS = std::ptr::null_mut();
|
||||
let mut bytes_returned = 0u32;
|
||||
|
||||
let result = WTSQuerySessionInformationW(
|
||||
WTS_CURRENT_SERVER_HANDLE,
|
||||
session_id,
|
||||
WTS_INFO_CLASS(14), // WTSClientAddress
|
||||
&mut buffer as *mut _ as *mut *mut u16,
|
||||
&mut bytes_returned,
|
||||
);
|
||||
|
||||
if !result.as_bool() || buffer.is_null() {
|
||||
log::debug!("WTSQuerySessionInformation(WTSClientAddress) failed");
|
||||
return None;
|
||||
}
|
||||
|
||||
let client_addr = &*buffer;
|
||||
let ip = match client_addr.AddressFamily as u32 {
|
||||
// AF_INET
|
||||
2 => {
|
||||
let addr = &client_addr.Address;
|
||||
Some(format!("{}.{}.{}.{}", addr[2], addr[3], addr[4], addr[5]))
|
||||
}
|
||||
// AF_INET6
|
||||
23 => {
|
||||
// IPv6 - extract from Address bytes
|
||||
let addr = &client_addr.Address;
|
||||
Some(format!(
|
||||
"{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}",
|
||||
addr[2], addr[3], addr[4], addr[5], addr[6], addr[7], addr[8], addr[9],
|
||||
addr[10], addr[11], addr[12], addr[13], addr[14], addr[15], addr[16], addr[17]
|
||||
))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
WTSFreeMemory(buffer as *mut std::ffi::c_void);
|
||||
|
||||
ip
|
||||
}
|
||||
}
|
||||
398
client/rdp/credprov/src/s4u.rs
Normal file
398
client/rdp/credprov/src/s4u.rs
Normal file
@@ -0,0 +1,398 @@
|
||||
//! S4U (Service for User) authentication for Windows.
|
||||
//!
|
||||
//! This module ports the S4U logon logic from the Go implementation at:
|
||||
//! `client/ssh/server/executor_windows.go:generateS4UUserToken()`
|
||||
//!
|
||||
//! It creates Windows logon tokens without requiring a password, using the LSA
|
||||
//! (Local Security Authority) S4U mechanism. This is the same approach used by
|
||||
//! OpenSSH for Windows for public key authentication.
|
||||
|
||||
use std::ptr;
|
||||
use windows::core::{PCSTR, PWSTR};
|
||||
use windows::Win32::Foundation::{HANDLE, LUID, NTSTATUS, PSID};
|
||||
use windows::Win32::Security::Authentication::Identity::{
|
||||
LsaDeregisterLogonProcess, LsaFreeReturnBuffer, LsaLogonUser, LsaLookupAuthenticationPackage,
|
||||
LsaRegisterLogonProcess, KERB_S4U_LOGON, MSV1_0_S4U_LOGON, MSV1_0_S4U_LOGON_FLAG_CHECK_LOGONHOURS,
|
||||
SECURITY_LOGON_TYPE,
|
||||
};
|
||||
use windows::Win32::Security::{
|
||||
QUOTA_LIMITS, TOKEN_SOURCE,
|
||||
};
|
||||
|
||||
/// Status code for successful LSA operations.
|
||||
const STATUS_SUCCESS: i32 = 0;
|
||||
|
||||
/// Network logon type (used for S4U).
|
||||
const LOGON32_LOGON_NETWORK: SECURITY_LOGON_TYPE = SECURITY_LOGON_TYPE(3);
|
||||
|
||||
/// Kerberos S4U logon message type.
|
||||
const KERB_S4U_LOGON_TYPE: u32 = 12;
|
||||
|
||||
/// MSV1_0 S4U logon message type.
|
||||
const MSV1_0_S4U_LOGON_TYPE: u32 = 12;
|
||||
|
||||
/// Authentication package name for Kerberos.
|
||||
const KERBEROS_PACKAGE: &str = "Kerberos";
|
||||
|
||||
/// Authentication package name for MSV1_0 (local users).
|
||||
const MSV1_0_PACKAGE: &str = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0";
|
||||
|
||||
/// Result of a successful S4U logon.
|
||||
pub struct S4UToken {
|
||||
pub handle: HANDLE,
|
||||
}
|
||||
|
||||
impl Drop for S4UToken {
|
||||
fn drop(&mut self) {
|
||||
if !self.handle.is_invalid() {
|
||||
unsafe {
|
||||
let _ = windows::Win32::Foundation::CloseHandle(self.handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors from S4U logon operations.
|
||||
#[derive(Debug)]
|
||||
pub enum S4UError {
|
||||
LsaRegister(NTSTATUS),
|
||||
LookupPackage(NTSTATUS),
|
||||
LogonUser(NTSTATUS, i32),
|
||||
AllocateLuid,
|
||||
InvalidUsername(String),
|
||||
Utf16Conversion(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for S4UError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
S4UError::LsaRegister(s) => write!(f, "LsaRegisterLogonProcess: 0x{:x}", s.0),
|
||||
S4UError::LookupPackage(s) => write!(f, "LsaLookupAuthenticationPackage: 0x{:x}", s.0),
|
||||
S4UError::LogonUser(s, sub) => {
|
||||
write!(f, "LsaLogonUser S4U: NTSTATUS=0x{:x}, SubStatus=0x{:x}", s.0, sub)
|
||||
}
|
||||
S4UError::AllocateLuid => write!(f, "AllocateLocallyUniqueId failed"),
|
||||
S4UError::InvalidUsername(u) => write!(f, "invalid username: {}", u),
|
||||
S4UError::Utf16Conversion(s) => write!(f, "UTF-16 conversion: {}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for S4UError {}
|
||||
|
||||
/// Generate a Windows logon token using S4U authentication.
|
||||
///
|
||||
/// This creates a token for the specified user without requiring a password.
|
||||
/// The calling process must have SeTcbPrivilege (typically SYSTEM).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `username` - The Windows username (without domain prefix)
|
||||
/// * `domain` - The domain name ("." for local users)
|
||||
///
|
||||
/// # Returns
|
||||
/// An `S4UToken` containing the Windows logon token handle.
|
||||
pub fn generate_s4u_token(username: &str, domain: &str) -> Result<S4UToken, S4UError> {
|
||||
if username.is_empty() {
|
||||
return Err(S4UError::InvalidUsername("empty username".to_string()));
|
||||
}
|
||||
|
||||
let is_local = is_local_user(domain);
|
||||
|
||||
// Initialize LSA connection
|
||||
let lsa_handle = initialize_lsa_connection()?;
|
||||
|
||||
// Lookup authentication package
|
||||
let auth_package_id = lookup_auth_package(lsa_handle, is_local)?;
|
||||
|
||||
// Perform S4U logon
|
||||
let result = perform_s4u_logon(lsa_handle, auth_package_id, username, domain, is_local);
|
||||
|
||||
// Cleanup LSA connection
|
||||
unsafe {
|
||||
let _ = LsaDeregisterLogonProcess(lsa_handle);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn is_local_user(domain: &str) -> bool {
|
||||
domain.is_empty() || domain == "."
|
||||
}
|
||||
|
||||
fn initialize_lsa_connection() -> Result<HANDLE, S4UError> {
|
||||
let process_name = "NetBird\0";
|
||||
let mut lsa_string = windows::Win32::Security::Authentication::Identity::LSA_STRING {
|
||||
Length: (process_name.len() - 1) as u16,
|
||||
MaximumLength: process_name.len() as u16,
|
||||
Buffer: windows::core::PSTR(process_name.as_ptr() as *mut u8),
|
||||
};
|
||||
|
||||
let mut lsa_handle = HANDLE::default();
|
||||
let mut mode = 0u32;
|
||||
|
||||
let status = unsafe {
|
||||
LsaRegisterLogonProcess(&mut lsa_string, &mut lsa_handle, &mut mode)
|
||||
};
|
||||
|
||||
if status.0 != STATUS_SUCCESS {
|
||||
return Err(S4UError::LsaRegister(status));
|
||||
}
|
||||
|
||||
Ok(lsa_handle)
|
||||
}
|
||||
|
||||
fn lookup_auth_package(lsa_handle: HANDLE, is_local: bool) -> Result<u32, S4UError> {
|
||||
let package_name = if is_local { MSV1_0_PACKAGE } else { KERBEROS_PACKAGE };
|
||||
let package_with_null = format!("{}\0", package_name);
|
||||
|
||||
let mut lsa_string = windows::Win32::Security::Authentication::Identity::LSA_STRING {
|
||||
Length: (package_with_null.len() - 1) as u16,
|
||||
MaximumLength: package_with_null.len() as u16,
|
||||
Buffer: windows::core::PSTR(package_with_null.as_ptr() as *mut u8),
|
||||
};
|
||||
|
||||
let mut auth_package_id = 0u32;
|
||||
let status = unsafe {
|
||||
LsaLookupAuthenticationPackage(lsa_handle, &mut lsa_string, &mut auth_package_id)
|
||||
};
|
||||
|
||||
if status.0 != STATUS_SUCCESS {
|
||||
return Err(S4UError::LookupPackage(status));
|
||||
}
|
||||
|
||||
Ok(auth_package_id)
|
||||
}
|
||||
|
||||
fn perform_s4u_logon(
|
||||
lsa_handle: HANDLE,
|
||||
auth_package_id: u32,
|
||||
username: &str,
|
||||
domain: &str,
|
||||
is_local: bool,
|
||||
) -> Result<S4UToken, S4UError> {
|
||||
// Prepare token source
|
||||
let mut source_name = [0u8; 8];
|
||||
let name_bytes = b"netbird";
|
||||
source_name[..name_bytes.len()].copy_from_slice(name_bytes);
|
||||
|
||||
let mut source_id = LUID::default();
|
||||
let alloc_ok = unsafe {
|
||||
windows::Win32::System::SystemInformation::GetSystemTimeAsFileTime(
|
||||
&mut std::mem::zeroed(),
|
||||
);
|
||||
// Use a simpler approach - just use the current time as a unique ID
|
||||
source_id.LowPart = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.subsec_nanos();
|
||||
source_id.HighPart = std::process::id() as i32;
|
||||
true
|
||||
};
|
||||
|
||||
if !alloc_ok {
|
||||
return Err(S4UError::AllocateLuid);
|
||||
}
|
||||
|
||||
let token_source = TOKEN_SOURCE {
|
||||
SourceName: source_name,
|
||||
SourceIdentifier: source_id,
|
||||
};
|
||||
|
||||
let origin_name_str = "netbird\0";
|
||||
let mut origin_name = windows::Win32::Security::Authentication::Identity::LSA_STRING {
|
||||
Length: (origin_name_str.len() - 1) as u16,
|
||||
MaximumLength: origin_name_str.len() as u16,
|
||||
Buffer: windows::core::PSTR(origin_name_str.as_ptr() as *mut u8),
|
||||
};
|
||||
|
||||
// Build the logon info structure
|
||||
let (logon_info_ptr, logon_info_size) = if is_local {
|
||||
build_msv1_0_s4u_logon(username)?
|
||||
} else {
|
||||
build_kerb_s4u_logon(username, domain)?
|
||||
};
|
||||
|
||||
let mut profile: *mut std::ffi::c_void = ptr::null_mut();
|
||||
let mut profile_size = 0u32;
|
||||
let mut logon_id = LUID::default();
|
||||
let mut token = HANDLE::default();
|
||||
let mut quotas = QUOTA_LIMITS::default();
|
||||
let mut sub_status: i32 = 0;
|
||||
|
||||
let status = unsafe {
|
||||
LsaLogonUser(
|
||||
lsa_handle,
|
||||
&mut origin_name,
|
||||
LOGON32_LOGON_NETWORK,
|
||||
auth_package_id,
|
||||
logon_info_ptr as *const std::ffi::c_void,
|
||||
logon_info_size as u32,
|
||||
None, // local groups
|
||||
&token_source,
|
||||
&mut profile,
|
||||
&mut profile_size,
|
||||
&mut logon_id,
|
||||
&mut token,
|
||||
&mut quotas,
|
||||
&mut sub_status,
|
||||
)
|
||||
};
|
||||
|
||||
// Free profile buffer if allocated
|
||||
if !profile.is_null() {
|
||||
unsafe {
|
||||
let _ = LsaFreeReturnBuffer(profile);
|
||||
}
|
||||
}
|
||||
|
||||
// Free the logon info buffer
|
||||
unsafe {
|
||||
let layout = std::alloc::Layout::from_size_align_unchecked(logon_info_size, 8);
|
||||
std::alloc::dealloc(logon_info_ptr as *mut u8, layout);
|
||||
}
|
||||
|
||||
if status.0 != STATUS_SUCCESS {
|
||||
return Err(S4UError::LogonUser(status, sub_status));
|
||||
}
|
||||
|
||||
Ok(S4UToken { handle: token })
|
||||
}
|
||||
|
||||
/// Build MSV1_0_S4U_LOGON structure for local users.
|
||||
fn build_msv1_0_s4u_logon(username: &str) -> Result<(*mut u8, usize), S4UError> {
|
||||
let username_utf16: Vec<u16> = username.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let domain_utf16: Vec<u16> = ".".encode_utf16().chain(std::iter::once(0)).collect();
|
||||
|
||||
let username_byte_size = username_utf16.len() * 2;
|
||||
let domain_byte_size = domain_utf16.len() * 2;
|
||||
|
||||
// MSV1_0_S4U_LOGON structure:
|
||||
// MessageType: u32 (4 bytes)
|
||||
// Flags: u32 (4 bytes)
|
||||
// UserPrincipalName: UNICODE_STRING (8 bytes on 32-bit, 16 bytes on 64-bit)
|
||||
// DomainName: UNICODE_STRING
|
||||
let struct_size = std::mem::size_of::<MSV1_0_S4U_LOGON_HEADER>();
|
||||
let total_size = struct_size + username_byte_size + domain_byte_size;
|
||||
|
||||
let layout = std::alloc::Layout::from_size_align(total_size, 8).unwrap();
|
||||
let buffer = unsafe { std::alloc::alloc_zeroed(layout) };
|
||||
|
||||
if buffer.is_null() {
|
||||
return Err(S4UError::Utf16Conversion("allocation failed".to_string()));
|
||||
}
|
||||
|
||||
// For the POC, we'll set up the raw bytes manually since the windows-rs
|
||||
// MSV1_0_S4U_LOGON structure layout may differ.
|
||||
// This is a simplified version - in production, use proper FFI bindings.
|
||||
|
||||
unsafe {
|
||||
// MessageType = MSV1_0_S4U_LOGON_TYPE (12)
|
||||
*(buffer as *mut u32) = MSV1_0_S4U_LOGON_TYPE;
|
||||
// Flags = 0
|
||||
*((buffer as *mut u32).add(1)) = 0;
|
||||
|
||||
// Copy username UTF-16 after the structure
|
||||
let username_offset = struct_size;
|
||||
let username_dest = buffer.add(username_offset);
|
||||
ptr::copy_nonoverlapping(
|
||||
username_utf16.as_ptr() as *const u8,
|
||||
username_dest,
|
||||
username_byte_size,
|
||||
);
|
||||
|
||||
// Copy domain UTF-16 after username
|
||||
let domain_offset = username_offset + username_byte_size;
|
||||
let domain_dest = buffer.add(domain_offset);
|
||||
ptr::copy_nonoverlapping(
|
||||
domain_utf16.as_ptr() as *const u8,
|
||||
domain_dest,
|
||||
domain_byte_size,
|
||||
);
|
||||
|
||||
// Set UNICODE_STRING for UserPrincipalName (offset 8 on 64-bit)
|
||||
// Length, MaximumLength, Buffer pointer
|
||||
let upn_ptr = buffer.add(8) as *mut u16;
|
||||
*upn_ptr = ((username_utf16.len() - 1) * 2) as u16; // Length (without null)
|
||||
*(upn_ptr.add(1)) = (username_utf16.len() * 2) as u16; // MaximumLength
|
||||
*((buffer.add(8 + 4)) as *mut *const u8) = username_dest; // Buffer
|
||||
|
||||
// Set UNICODE_STRING for DomainName
|
||||
let dn_offset = 8 + std::mem::size_of::<UnicodeStringRaw>();
|
||||
let dn_ptr = buffer.add(dn_offset) as *mut u16;
|
||||
*dn_ptr = ((domain_utf16.len() - 1) * 2) as u16;
|
||||
*(dn_ptr.add(1)) = (domain_utf16.len() * 2) as u16;
|
||||
*((buffer.add(dn_offset + 4)) as *mut *const u8) = domain_dest;
|
||||
}
|
||||
|
||||
Ok((buffer, total_size))
|
||||
}
|
||||
|
||||
/// Build KERB_S4U_LOGON structure for domain users.
|
||||
fn build_kerb_s4u_logon(username: &str, domain: &str) -> Result<(*mut u8, usize), S4UError> {
|
||||
// Build UPN: username@domain
|
||||
let upn = format!("{}@{}", username, domain);
|
||||
let upn_utf16: Vec<u16> = upn.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let upn_byte_size = upn_utf16.len() * 2;
|
||||
|
||||
let struct_size = std::mem::size_of::<KerbS4ULogonHeader>();
|
||||
let total_size = struct_size + upn_byte_size;
|
||||
|
||||
let layout = std::alloc::Layout::from_size_align(total_size, 8).unwrap();
|
||||
let buffer = unsafe { std::alloc::alloc_zeroed(layout) };
|
||||
|
||||
if buffer.is_null() {
|
||||
return Err(S4UError::Utf16Conversion("allocation failed".to_string()));
|
||||
}
|
||||
|
||||
unsafe {
|
||||
// MessageType = KERB_S4U_LOGON_TYPE (12)
|
||||
*(buffer as *mut u32) = KERB_S4U_LOGON_TYPE;
|
||||
// Flags = 0
|
||||
*((buffer as *mut u32).add(1)) = 0;
|
||||
|
||||
// Copy UPN UTF-16 after the structure
|
||||
let upn_offset = struct_size;
|
||||
let upn_dest = buffer.add(upn_offset);
|
||||
ptr::copy_nonoverlapping(
|
||||
upn_utf16.as_ptr() as *const u8,
|
||||
upn_dest,
|
||||
upn_byte_size,
|
||||
);
|
||||
|
||||
// Set UNICODE_STRING for ClientUpn (offset 8)
|
||||
let upn_str_ptr = buffer.add(8) as *mut u16;
|
||||
*upn_str_ptr = ((upn_utf16.len() - 1) * 2) as u16;
|
||||
*(upn_str_ptr.add(1)) = (upn_utf16.len() * 2) as u16;
|
||||
*((buffer.add(8 + 4)) as *mut *const u8) = upn_dest;
|
||||
|
||||
// ClientRealm is empty (zeroed)
|
||||
}
|
||||
|
||||
Ok((buffer, total_size))
|
||||
}
|
||||
|
||||
/// Raw UNICODE_STRING layout for size calculation.
|
||||
#[repr(C)]
|
||||
struct UnicodeStringRaw {
|
||||
_length: u16,
|
||||
_maximum_length: u16,
|
||||
_buffer: *const u16,
|
||||
}
|
||||
|
||||
/// Header size for MSV1_0_S4U_LOGON (MessageType + Flags + 2x UNICODE_STRING).
|
||||
#[repr(C)]
|
||||
struct MSV1_0_S4U_LOGON_HEADER {
|
||||
_message_type: u32,
|
||||
_flags: u32,
|
||||
_user_principal_name: UnicodeStringRaw,
|
||||
_domain_name: UnicodeStringRaw,
|
||||
}
|
||||
|
||||
/// Header size for KERB_S4U_LOGON (MessageType + Flags + 2x UNICODE_STRING).
|
||||
#[repr(C)]
|
||||
struct KerbS4ULogonHeader {
|
||||
_message_type: u32,
|
||||
_flags: u32,
|
||||
_client_upn: UnicodeStringRaw,
|
||||
_client_realm: UnicodeStringRaw,
|
||||
}
|
||||
21
client/rdp/server/addr.go
Normal file
21
client/rdp/server/addr.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// parseAddr parses a string into a netip.Addr, stripping any port or zone.
|
||||
func parseAddr(s string) (netip.Addr, error) {
|
||||
// Try as plain IP first
|
||||
if addr, err := netip.ParseAddr(s); err == nil {
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
// Try as IP:port
|
||||
if addrPort, err := netip.ParseAddrPort(s); err == nil {
|
||||
return addrPort.Addr(), nil
|
||||
}
|
||||
|
||||
return netip.Addr{}, fmt.Errorf("invalid IP address: %s", s)
|
||||
}
|
||||
13
client/rdp/server/credprov_stub.go
Normal file
13
client/rdp/server/credprov_stub.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
// RegisterCredentialProvider is a no-op on non-Windows platforms.
|
||||
func RegisterCredentialProvider() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnregisterCredentialProvider is a no-op on non-Windows platforms.
|
||||
func UnregisterCredentialProvider() error {
|
||||
return nil
|
||||
}
|
||||
66
client/rdp/server/credprov_windows.go
Normal file
66
client/rdp/server/credprov_windows.go
Normal file
@@ -0,0 +1,66 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// credProvDLLName is the filename of the credential provider DLL.
|
||||
credProvDLLName = "netbird_credprov.dll"
|
||||
)
|
||||
|
||||
// RegisterCredentialProvider registers the NetBird Credential Provider COM DLL
|
||||
// using regsvr32. The DLL must be shipped alongside the NetBird executable.
|
||||
func RegisterCredentialProvider() error {
|
||||
dllPath, err := findCredProvDLL()
|
||||
if err != nil {
|
||||
return fmt.Errorf("find credential provider DLL: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("regsvr32", "/s", dllPath)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("regsvr32 %s: %w (output: %s)", dllPath, err, string(output))
|
||||
}
|
||||
|
||||
log.Infof("registered RDP credential provider: %s", dllPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnregisterCredentialProvider unregisters the NetBird Credential Provider COM DLL.
|
||||
func UnregisterCredentialProvider() error {
|
||||
dllPath, err := findCredProvDLL()
|
||||
if err != nil {
|
||||
log.Debugf("credential provider DLL not found for unregistration: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := exec.Command("regsvr32", "/s", "/u", dllPath)
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("regsvr32 /u %s: %w (output: %s)", dllPath, err, string(output))
|
||||
}
|
||||
|
||||
log.Infof("unregistered RDP credential provider: %s", dllPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// findCredProvDLL locates the credential provider DLL next to the running executable.
|
||||
func findCredProvDLL() (string, error) {
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get executable path: %w", err)
|
||||
}
|
||||
|
||||
dllPath := filepath.Join(filepath.Dir(exePath), credProvDLLName)
|
||||
if _, err := os.Stat(dllPath); err != nil {
|
||||
return "", fmt.Errorf("DLL not found at %s: %w", dllPath, err)
|
||||
}
|
||||
|
||||
return dllPath, nil
|
||||
}
|
||||
184
client/rdp/server/pending.go
Normal file
184
client/rdp/server/pending.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultSessionTTL is the default time-to-live for pending RDP sessions.
|
||||
DefaultSessionTTL = 60 * time.Second
|
||||
|
||||
// cleanupInterval is how often the store checks for expired sessions.
|
||||
cleanupInterval = 10 * time.Second
|
||||
|
||||
// nonceLength is the length of the nonce in bytes.
|
||||
nonceLength = 32
|
||||
)
|
||||
|
||||
// PendingRDPSession represents an authorized but not yet consumed RDP session.
|
||||
type PendingRDPSession struct {
|
||||
SessionID string
|
||||
PeerIP netip.Addr
|
||||
OSUsername string
|
||||
Domain string
|
||||
JWTUserID string // for audit trail
|
||||
Nonce string // replay protection
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
consumed bool
|
||||
}
|
||||
|
||||
// PendingStore manages pending RDP session entries with automatic expiration.
|
||||
type PendingStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*PendingRDPSession // keyed by SessionID
|
||||
nonces map[string]struct{} // seen nonces for replay protection
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// NewPendingStore creates a new pending session store with the given TTL.
|
||||
func NewPendingStore(ttl time.Duration) *PendingStore {
|
||||
if ttl <= 0 {
|
||||
ttl = DefaultSessionTTL
|
||||
}
|
||||
return &PendingStore{
|
||||
sessions: make(map[string]*PendingRDPSession),
|
||||
nonces: make(map[string]struct{}),
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// Add creates a new pending RDP session and returns it.
|
||||
func (ps *PendingStore) Add(peerIP netip.Addr, osUsername, domain, jwtUserID, nonce string) (*PendingRDPSession, error) {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
// Check nonce for replay protection
|
||||
if _, seen := ps.nonces[nonce]; seen {
|
||||
return nil, fmt.Errorf("duplicate nonce: replay detected")
|
||||
}
|
||||
ps.nonces[nonce] = struct{}{}
|
||||
|
||||
now := time.Now()
|
||||
session := &PendingRDPSession{
|
||||
SessionID: uuid.New().String(),
|
||||
PeerIP: peerIP,
|
||||
OSUsername: osUsername,
|
||||
Domain: domain,
|
||||
JWTUserID: jwtUserID,
|
||||
Nonce: nonce,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(ps.ttl),
|
||||
}
|
||||
|
||||
ps.sessions[session.SessionID] = session
|
||||
|
||||
log.Debugf("RDP pending session created: id=%s peer=%s user=%s domain=%s expires=%s",
|
||||
session.SessionID, peerIP, osUsername, domain, session.ExpiresAt.Format(time.RFC3339))
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// QueryByPeerIP finds the first non-consumed, non-expired pending session for the given peer IP.
|
||||
func (ps *PendingStore) QueryByPeerIP(peerIP netip.Addr) (*PendingRDPSession, bool) {
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
for _, session := range ps.sessions {
|
||||
if session.PeerIP == peerIP && !session.consumed && now.Before(session.ExpiresAt) {
|
||||
return session, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Consume marks a session as consumed (single-use). Returns true if the session
|
||||
// was found and successfully consumed, false if it was already consumed, expired, or not found.
|
||||
func (ps *PendingStore) Consume(sessionID string) bool {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
session, exists := ps.sessions[sessionID]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if session.consumed {
|
||||
log.Debugf("RDP pending session already consumed: id=%s", sessionID)
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
log.Debugf("RDP pending session expired: id=%s", sessionID)
|
||||
return false
|
||||
}
|
||||
|
||||
session.consumed = true
|
||||
log.Debugf("RDP pending session consumed: id=%s peer=%s user=%s",
|
||||
sessionID, session.PeerIP, session.OSUsername)
|
||||
return true
|
||||
}
|
||||
|
||||
// StartCleanup runs a background goroutine that periodically removes expired sessions.
|
||||
func (ps *PendingStore) StartCleanup(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
ps.cleanup()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// cleanup removes expired and consumed sessions.
|
||||
func (ps *PendingStore) cleanup() {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, session := range ps.sessions {
|
||||
if now.After(session.ExpiresAt) || session.consumed {
|
||||
delete(ps.sessions, id)
|
||||
delete(ps.nonces, session.Nonce)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of active (non-expired, non-consumed) sessions.
|
||||
func (ps *PendingStore) Count() int {
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
now := time.Now()
|
||||
for _, session := range ps.sessions {
|
||||
if !session.consumed && now.Before(session.ExpiresAt) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// GenerateNonce creates a cryptographically random nonce for replay protection.
|
||||
func GenerateNonce() (string, error) {
|
||||
b := make([]byte, nonceLength)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate nonce: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
268
client/rdp/server/pending_test.go
Normal file
268
client/rdp/server/pending_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPendingStore_AddAndQuery(t *testing.T) {
|
||||
store := NewPendingStore(DefaultSessionTTL)
|
||||
|
||||
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Add failed: %v", err)
|
||||
}
|
||||
|
||||
if session.SessionID == "" {
|
||||
t.Fatal("expected non-empty session ID")
|
||||
}
|
||||
if session.PeerIP != peerIP {
|
||||
t.Errorf("expected peer IP %s, got %s", peerIP, session.PeerIP)
|
||||
}
|
||||
if session.OSUsername != "admin" {
|
||||
t.Errorf("expected username admin, got %s", session.OSUsername)
|
||||
}
|
||||
|
||||
// Query should find the session
|
||||
found, ok := store.QueryByPeerIP(peerIP)
|
||||
if !ok {
|
||||
t.Fatal("expected to find pending session")
|
||||
}
|
||||
if found.SessionID != session.SessionID {
|
||||
t.Errorf("expected session %s, got %s", session.SessionID, found.SessionID)
|
||||
}
|
||||
|
||||
// Query for different IP should not find anything
|
||||
_, ok = store.QueryByPeerIP(netip.MustParseAddr("100.64.0.2"))
|
||||
if ok {
|
||||
t.Fatal("expected no session for different IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingStore_Consume(t *testing.T) {
|
||||
store := NewPendingStore(DefaultSessionTTL)
|
||||
|
||||
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-2")
|
||||
if err != nil {
|
||||
t.Fatalf("Add failed: %v", err)
|
||||
}
|
||||
|
||||
// First consume should succeed
|
||||
if !store.Consume(session.SessionID) {
|
||||
t.Fatal("expected first consume to succeed")
|
||||
}
|
||||
|
||||
// Second consume should fail (already consumed)
|
||||
if store.Consume(session.SessionID) {
|
||||
t.Fatal("expected second consume to fail")
|
||||
}
|
||||
|
||||
// Query should no longer find consumed session
|
||||
_, ok := store.QueryByPeerIP(peerIP)
|
||||
if ok {
|
||||
t.Fatal("expected consumed session to not be found by query")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingStore_Expiry(t *testing.T) {
|
||||
store := NewPendingStore(50 * time.Millisecond)
|
||||
|
||||
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-3")
|
||||
if err != nil {
|
||||
t.Fatalf("Add failed: %v", err)
|
||||
}
|
||||
|
||||
// Should be found immediately
|
||||
_, ok := store.QueryByPeerIP(peerIP)
|
||||
if !ok {
|
||||
t.Fatal("expected to find session before expiry")
|
||||
}
|
||||
|
||||
// Wait for expiry
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Should not be found after expiry
|
||||
_, ok = store.QueryByPeerIP(peerIP)
|
||||
if ok {
|
||||
t.Fatal("expected session to be expired")
|
||||
}
|
||||
|
||||
// Consume should also fail
|
||||
if store.Consume(session.SessionID) {
|
||||
t.Fatal("expected consume of expired session to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingStore_ReplayProtection(t *testing.T) {
|
||||
store := NewPendingStore(DefaultSessionTTL)
|
||||
|
||||
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-same")
|
||||
if err != nil {
|
||||
t.Fatalf("first Add failed: %v", err)
|
||||
}
|
||||
|
||||
// Same nonce should be rejected
|
||||
_, err = store.Add(peerIP, "admin", ".", "user@example.com", "nonce-same")
|
||||
if err == nil {
|
||||
t.Fatal("expected duplicate nonce to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingStore_Cleanup(t *testing.T) {
|
||||
store := NewPendingStore(50 * time.Millisecond)
|
||||
|
||||
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-cleanup")
|
||||
if err != nil {
|
||||
t.Fatalf("Add failed: %v", err)
|
||||
}
|
||||
|
||||
if store.Count() != 1 {
|
||||
t.Fatalf("expected count 1, got %d", store.Count())
|
||||
}
|
||||
|
||||
// Wait for expiry then trigger cleanup
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
store.cleanup()
|
||||
|
||||
if store.Count() != 0 {
|
||||
t.Fatalf("expected count 0 after cleanup, got %d", store.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingStore_CleanupBackground(t *testing.T) {
|
||||
store := NewPendingStore(50 * time.Millisecond)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
store.StartCleanup(ctx)
|
||||
|
||||
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-bg-cleanup")
|
||||
if err != nil {
|
||||
t.Fatalf("Add failed: %v", err)
|
||||
}
|
||||
|
||||
// Wait for expiry + cleanup interval
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
_, ok := store.QueryByPeerIP(peerIP)
|
||||
if ok {
|
||||
t.Fatal("expected session to be cleaned up")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingStore_ConcurrentAccess(t *testing.T) {
|
||||
store := NewPendingStore(DefaultSessionTTL)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
ip := netip.AddrFrom4([4]byte{100, 64, byte(i / 256), byte(i % 256)})
|
||||
nonce := "nonce-" + string(rune(i+'A'))
|
||||
if i >= 26 {
|
||||
nonce = "nonce-" + string(rune(i-26+'a'))
|
||||
}
|
||||
|
||||
session, err := store.Add(ip, "admin", ".", "user", nonce)
|
||||
if err != nil {
|
||||
return // nonce collision in test is expected
|
||||
}
|
||||
|
||||
store.QueryByPeerIP(ip)
|
||||
store.Consume(session.SessionID)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPendingStore_MultipleSessions(t *testing.T) {
|
||||
store := NewPendingStore(DefaultSessionTTL)
|
||||
|
||||
ip1 := netip.MustParseAddr("100.64.0.1")
|
||||
ip2 := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
s1, err := store.Add(ip1, "admin", ".", "user1", "nonce-a")
|
||||
if err != nil {
|
||||
t.Fatalf("Add s1 failed: %v", err)
|
||||
}
|
||||
|
||||
s2, err := store.Add(ip2, "jdoe", "DOMAIN", "user2", "nonce-b")
|
||||
if err != nil {
|
||||
t.Fatalf("Add s2 failed: %v", err)
|
||||
}
|
||||
|
||||
// Query each
|
||||
found1, ok := store.QueryByPeerIP(ip1)
|
||||
if !ok || found1.SessionID != s1.SessionID {
|
||||
t.Fatal("expected to find s1")
|
||||
}
|
||||
|
||||
found2, ok := store.QueryByPeerIP(ip2)
|
||||
if !ok || found2.SessionID != s2.SessionID {
|
||||
t.Fatal("expected to find s2")
|
||||
}
|
||||
|
||||
if found2.Domain != "DOMAIN" {
|
||||
t.Errorf("expected domain DOMAIN, got %s", found2.Domain)
|
||||
}
|
||||
|
||||
if store.Count() != 2 {
|
||||
t.Errorf("expected count 2, got %d", store.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateNonce(t *testing.T) {
|
||||
nonce1, err := GenerateNonce()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateNonce failed: %v", err)
|
||||
}
|
||||
|
||||
nonce2, err := GenerateNonce()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateNonce failed: %v", err)
|
||||
}
|
||||
|
||||
if len(nonce1) != nonceLength*2 { // hex encoding doubles the length
|
||||
t.Errorf("expected nonce length %d, got %d", nonceLength*2, len(nonce1))
|
||||
}
|
||||
|
||||
if nonce1 == nonce2 {
|
||||
t.Error("expected unique nonces")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseWindowsUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expectedUser string
|
||||
expectedDomain string
|
||||
}{
|
||||
{"admin", "admin", "."},
|
||||
{"DOMAIN\\admin", "admin", "DOMAIN"},
|
||||
{"admin@domain.com", "admin", "domain.com"},
|
||||
{".\\localuser", "localuser", "."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
user, domain := parseWindowsUsername(tt.input)
|
||||
if user != tt.expectedUser {
|
||||
t.Errorf("parseWindowsUsername(%q) user = %q, want %q", tt.input, user, tt.expectedUser)
|
||||
}
|
||||
if domain != tt.expectedDomain {
|
||||
t.Errorf("parseWindowsUsername(%q) domain = %q, want %q", tt.input, domain, tt.expectedDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
19
client/rdp/server/pipe_stub.go
Normal file
19
client/rdp/server/pipe_stub.go
Normal file
@@ -0,0 +1,19 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
import "context"
|
||||
|
||||
type stubPipeServer struct{}
|
||||
|
||||
func newPipeServer(_ *PendingStore) PipeServer {
|
||||
return &stubPipeServer{}
|
||||
}
|
||||
|
||||
func (s *stubPipeServer) Start(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubPipeServer) Stop() error {
|
||||
return nil
|
||||
}
|
||||
164
client/rdp/server/pipe_windows.go
Normal file
164
client/rdp/server/pipe_windows.go
Normal file
@@ -0,0 +1,164 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/Microsoft/go-winio"
|
||||
)
|
||||
|
||||
const (
|
||||
// PipeName is the named pipe path used for IPC between the NetBird agent and
|
||||
// the Credential Provider DLL.
|
||||
PipeName = `\\.\pipe\netbird-rdp-auth`
|
||||
|
||||
// pipeSDDL restricts access to LOCAL_SYSTEM (SY) and Administrators (BA).
|
||||
pipeSDDL = "D:P(A;;GA;;;SY)(A;;GA;;;BA)"
|
||||
|
||||
// maxPipeRequestSize is the maximum size of a pipe request in bytes.
|
||||
maxPipeRequestSize = 4096
|
||||
)
|
||||
|
||||
// windowsPipeServer implements the PipeServer interface for Windows.
|
||||
type windowsPipeServer struct {
|
||||
pending *PendingStore
|
||||
listener net.Listener
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newPipeServer(pending *PendingStore) PipeServer {
|
||||
return &windowsPipeServer{
|
||||
pending: pending,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *windowsPipeServer) Start(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||
|
||||
cfg := &winio.PipeConfig{
|
||||
SecurityDescriptor: pipeSDDL,
|
||||
}
|
||||
|
||||
listener, err := winio.ListenPipe(PipeName, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = listener
|
||||
|
||||
go s.acceptLoop()
|
||||
|
||||
log.Infof("RDP named pipe server started on %s", PipeName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *windowsPipeServer) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
if s.listener != nil {
|
||||
err := s.listener.Close()
|
||||
s.listener = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *windowsPipeServer) acceptLoop() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if s.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("RDP pipe accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handlePipeConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *windowsPipeServer) handlePipeConnection(conn net.Conn) {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("RDP pipe close: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(conn, maxPipeRequestSize))
|
||||
if err != nil {
|
||||
log.Debugf("RDP pipe read: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var req PipeRequest
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
log.Debugf("RDP pipe unmarshal: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var resp PipeResponse
|
||||
|
||||
switch req.Action {
|
||||
case PipeActionQuery:
|
||||
resp = s.handleQuery(req.RemoteIP)
|
||||
case PipeActionConsume:
|
||||
resp = s.handleConsume(req.SessionID)
|
||||
default:
|
||||
log.Debugf("RDP pipe unknown action: %s", req.Action)
|
||||
return
|
||||
}
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Debugf("RDP pipe marshal response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := conn.Write(respData); err != nil {
|
||||
log.Debugf("RDP pipe write response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *windowsPipeServer) handleQuery(remoteIP string) PipeResponse {
|
||||
peerIP, err := parseAddr(remoteIP)
|
||||
if err != nil {
|
||||
log.Debugf("RDP pipe invalid remote IP: %s", remoteIP)
|
||||
return PipeResponse{Found: false}
|
||||
}
|
||||
|
||||
session, found := s.pending.QueryByPeerIP(peerIP)
|
||||
if !found {
|
||||
return PipeResponse{Found: false}
|
||||
}
|
||||
|
||||
return PipeResponse{
|
||||
Found: true,
|
||||
SessionID: session.SessionID,
|
||||
OSUser: session.OSUsername,
|
||||
Domain: session.Domain,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *windowsPipeServer) handleConsume(sessionID string) PipeResponse {
|
||||
if s.pending.Consume(sessionID) {
|
||||
return PipeResponse{Found: true, SessionID: sessionID}
|
||||
}
|
||||
return PipeResponse{Found: false}
|
||||
}
|
||||
48
client/rdp/server/protocol.go
Normal file
48
client/rdp/server/protocol.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package server
|
||||
|
||||
// AuthRequest is the sideband authorization request sent by the connecting peer
|
||||
// to the target peer's RDP auth server over the WireGuard tunnel.
|
||||
type AuthRequest struct {
|
||||
JWTToken string `json:"jwt_token"`
|
||||
RequestedUser string `json:"requested_user"`
|
||||
ClientPeerIP string `json:"client_peer_ip"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
// AuthResponse is the sideband authorization response sent by the target peer
|
||||
// back to the connecting peer.
|
||||
type AuthResponse struct {
|
||||
Status string `json:"status"` // "authorized" or "denied"
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
ExpiresAt int64 `json:"expires_at,omitempty"` // unix timestamp
|
||||
OSUser string `json:"os_user,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// PipeRequest is the IPC request from the Credential Provider DLL to the NetBird agent
|
||||
// via the named pipe.
|
||||
type PipeRequest struct {
|
||||
Action string `json:"action"` // "query_pending" or "consume"
|
||||
RemoteIP string `json:"remote_ip"` // connecting peer's WG IP
|
||||
SessionID string `json:"session_id,omitempty"` // for consume action
|
||||
}
|
||||
|
||||
// PipeResponse is the IPC response from the NetBird agent to the Credential Provider DLL.
|
||||
type PipeResponse struct {
|
||||
Found bool `json:"found"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
OSUser string `json:"os_user,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// StatusAuthorized indicates the RDP session was authorized.
|
||||
StatusAuthorized = "authorized"
|
||||
// StatusDenied indicates the RDP session was denied.
|
||||
StatusDenied = "denied"
|
||||
|
||||
// PipeActionQuery queries for a pending session by remote IP.
|
||||
PipeActionQuery = "query_pending"
|
||||
// PipeActionConsume marks a pending session as consumed.
|
||||
PipeActionConsume = "consume"
|
||||
)
|
||||
301
client/rdp/server/server.go
Normal file
301
client/rdp/server/server.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
)
|
||||
|
||||
const (
|
||||
// InternalRDPAuthPort is the port the sideband auth server listens on.
|
||||
InternalRDPAuthPort = 22338
|
||||
|
||||
// DefaultRDPAuthPort is the external port on the WireGuard interface (DNAT target).
|
||||
DefaultRDPAuthPort = 22338
|
||||
|
||||
// maxRequestSize is the maximum size of an auth request in bytes.
|
||||
maxRequestSize = 64 * 1024
|
||||
|
||||
// connectionTimeout is the timeout for a single auth connection.
|
||||
connectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// JWTValidator validates JWT tokens and extracts user identity.
|
||||
type JWTValidator interface {
|
||||
ValidateAndExtract(token string) (userID string, err error)
|
||||
}
|
||||
|
||||
// Authorizer checks if a user is authorized for RDP access.
|
||||
type Authorizer interface {
|
||||
Authorize(jwtUserID, osUsername string) (string, error)
|
||||
}
|
||||
|
||||
// Server is the sideband RDP authorization server that listens on the WireGuard interface.
|
||||
type Server struct {
|
||||
listener net.Listener
|
||||
pending *PendingStore
|
||||
pipeServer PipeServer
|
||||
jwtValidator JWTValidator
|
||||
authorizer Authorizer
|
||||
sshAuthorizer *sshauth.Authorizer // reuses SSH ACL for RDP access control
|
||||
networkAddr netip.Prefix // WireGuard network for source IP validation
|
||||
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// PipeServer is the interface for the named pipe IPC server (platform-specific).
|
||||
type PipeServer interface {
|
||||
Start(ctx context.Context) error
|
||||
Stop() error
|
||||
}
|
||||
|
||||
// Config holds the configuration for the RDP auth server.
|
||||
type Config struct {
|
||||
JWTValidator JWTValidator
|
||||
Authorizer Authorizer
|
||||
NetworkAddr netip.Prefix
|
||||
SessionTTL time.Duration
|
||||
}
|
||||
|
||||
// New creates a new RDP sideband auth server.
|
||||
func New(cfg *Config) *Server {
|
||||
ttl := cfg.SessionTTL
|
||||
if ttl <= 0 {
|
||||
ttl = DefaultSessionTTL
|
||||
}
|
||||
|
||||
pending := NewPendingStore(ttl)
|
||||
|
||||
return &Server{
|
||||
pending: pending,
|
||||
pipeServer: newPipeServer(pending),
|
||||
jwtValidator: cfg.JWTValidator,
|
||||
authorizer: cfg.Authorizer,
|
||||
sshAuthorizer: sshauth.NewAuthorizer(),
|
||||
networkAddr: cfg.NetworkAddr,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateRDPAuth updates the RDP authorization config (reuses SSH ACL).
|
||||
func (s *Server) UpdateRDPAuth(config *sshauth.Config) {
|
||||
s.sshAuthorizer.Update(config)
|
||||
log.Debugf("RDP auth: updated authorization config")
|
||||
}
|
||||
|
||||
// Start begins listening for sideband auth requests on the given address.
|
||||
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener != nil {
|
||||
return errors.New("RDP auth server already running")
|
||||
}
|
||||
|
||||
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||
|
||||
listenAddr := net.TCPAddrFromAddrPort(addr)
|
||||
listener, err := net.ListenTCP("tcp", listenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||
}
|
||||
s.listener = listener
|
||||
|
||||
s.pending.StartCleanup(s.ctx)
|
||||
|
||||
if s.pipeServer != nil {
|
||||
if err := s.pipeServer.Start(s.ctx); err != nil {
|
||||
log.Warnf("failed to start RDP named pipe server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
go s.acceptLoop()
|
||||
|
||||
log.Infof("RDP sideband auth server started on %s", addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop shuts down the server and cleans up resources.
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
if s.pipeServer != nil {
|
||||
if err := s.pipeServer.Stop(); err != nil {
|
||||
log.Warnf("failed to stop RDP named pipe server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.listener != nil {
|
||||
err := s.listener.Close()
|
||||
s.listener = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("close listener: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("RDP sideband auth server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPendingStore returns the pending session store (for testing/named pipe access).
|
||||
func (s *Server) GetPendingStore() *PendingStore {
|
||||
return s.pending
|
||||
}
|
||||
|
||||
func (s *Server) acceptLoop() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if s.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("RDP auth accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("RDP auth close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(connectionTimeout)); err != nil {
|
||||
log.Debugf("RDP auth set deadline: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate source IP is from WireGuard network
|
||||
remoteAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
log.Debugf("RDP auth parse remote addr: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !s.networkAddr.Contains(remoteAddr.Addr()) {
|
||||
log.Warnf("RDP auth rejected connection from non-WG address: %s", remoteAddr.Addr())
|
||||
return
|
||||
}
|
||||
|
||||
// Read request
|
||||
data, err := io.ReadAll(io.LimitReader(conn, maxRequestSize))
|
||||
if err != nil {
|
||||
log.Debugf("RDP auth read request: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var req AuthRequest
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
log.Debugf("RDP auth unmarshal request: %v", err)
|
||||
s.sendResponse(conn, &AuthResponse{Status: StatusDenied, Reason: "invalid request format"})
|
||||
return
|
||||
}
|
||||
|
||||
response := s.processAuthRequest(remoteAddr.Addr(), &req)
|
||||
s.sendResponse(conn, response)
|
||||
}
|
||||
|
||||
func (s *Server) processAuthRequest(peerIP netip.Addr, req *AuthRequest) *AuthResponse {
|
||||
// Validate JWT
|
||||
if s.jwtValidator == nil {
|
||||
// No JWT validation configured - for POC, accept all requests from WG peers
|
||||
log.Warnf("RDP auth: no JWT validator configured, accepting request from %s", peerIP)
|
||||
return s.createSession(peerIP, req, "no-jwt-validation")
|
||||
}
|
||||
|
||||
userID, err := s.jwtValidator.ValidateAndExtract(req.JWTToken)
|
||||
if err != nil {
|
||||
log.Warnf("RDP auth JWT validation failed for %s: %v", peerIP, err)
|
||||
return &AuthResponse{Status: StatusDenied, Reason: "JWT validation failed"}
|
||||
}
|
||||
|
||||
// Check authorization - try explicit authorizer first, then SSH ACL
|
||||
if s.authorizer != nil {
|
||||
if _, err := s.authorizer.Authorize(userID, req.RequestedUser); err != nil {
|
||||
log.Warnf("RDP auth denied for user %s -> %s: %v", userID, req.RequestedUser, err)
|
||||
return &AuthResponse{Status: StatusDenied, Reason: "not authorized for this user"}
|
||||
}
|
||||
} else if s.sshAuthorizer != nil {
|
||||
if _, err := s.sshAuthorizer.Authorize(userID, req.RequestedUser); err != nil {
|
||||
log.Warnf("RDP auth denied (SSH ACL) for user %s -> %s: %v", userID, req.RequestedUser, err)
|
||||
return &AuthResponse{Status: StatusDenied, Reason: "not authorized for this user"}
|
||||
}
|
||||
}
|
||||
|
||||
return s.createSession(peerIP, req, userID)
|
||||
}
|
||||
|
||||
func (s *Server) createSession(peerIP netip.Addr, req *AuthRequest, jwtUserID string) *AuthResponse {
|
||||
// Parse domain from requested user (DOMAIN\user or user@domain)
|
||||
osUser, domain := parseWindowsUsername(req.RequestedUser)
|
||||
|
||||
session, err := s.pending.Add(peerIP, osUser, domain, jwtUserID, req.Nonce)
|
||||
if err != nil {
|
||||
log.Warnf("RDP auth create session failed: %v", err)
|
||||
return &AuthResponse{Status: StatusDenied, Reason: err.Error()}
|
||||
}
|
||||
|
||||
return &AuthResponse{
|
||||
Status: StatusAuthorized,
|
||||
SessionID: session.SessionID,
|
||||
ExpiresAt: session.ExpiresAt.Unix(),
|
||||
OSUser: session.OSUsername,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) sendResponse(conn net.Conn, resp *AuthResponse) {
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Debugf("RDP auth marshal response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
log.Debugf("RDP auth write response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// parseWindowsUsername extracts username and domain from Windows username formats.
|
||||
// Supports DOMAIN\username, username@domain, and plain username.
|
||||
func parseWindowsUsername(fullUsername string) (username, domain string) {
|
||||
for i := len(fullUsername) - 1; i >= 0; i-- {
|
||||
if fullUsername[i] == '\\' {
|
||||
return fullUsername[i+1:], fullUsername[:i]
|
||||
}
|
||||
}
|
||||
|
||||
if idx := indexOf(fullUsername, '@'); idx != -1 {
|
||||
return fullUsername[:idx], fullUsername[idx+1:]
|
||||
}
|
||||
|
||||
return fullUsername, "."
|
||||
}
|
||||
|
||||
func indexOf(s string, c byte) int {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == c {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
@@ -366,6 +366,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
||||
config.ServerRDPAllowed = msg.ServerRDPAllowed
|
||||
config.NetworkMonitor = msg.NetworkMonitor
|
||||
config.DisableClientRoutes = msg.DisableClientRoutes
|
||||
config.DisableServerRoutes = msg.DisableServerRoutes
|
||||
@@ -1359,6 +1360,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
|
||||
}
|
||||
|
||||
if engine.IsBlockInbound() {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "expose requires inbound connections but 'block inbound' is enabled, disable it first")
|
||||
}
|
||||
|
||||
mgr := engine.GetExposeManager()
|
||||
if mgr == nil {
|
||||
return gstatus.Errorf(codes.Internal, "expose manager not available")
|
||||
@@ -1510,6 +1515,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
ServerRDPAllowed: cfg.ServerRDPAllowed != nil && *cfg.ServerRDPAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
|
||||
@@ -58,6 +58,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
rosenpassEnabled := true
|
||||
rosenpassPermissive := true
|
||||
serverSSHAllowed := true
|
||||
serverRDPAllowed := true
|
||||
interfaceName := "utun100"
|
||||
wireguardPort := int64(51820)
|
||||
preSharedKey := "test-psk"
|
||||
@@ -82,6 +83,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
ServerRDPAllowed: &serverRDPAllowed,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
@@ -125,6 +127,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||
require.NotNil(t, cfg.ServerRDPAllowed)
|
||||
require.Equal(t, serverRDPAllowed, *cfg.ServerRDPAllowed)
|
||||
require.Equal(t, interfaceName, cfg.WgIface)
|
||||
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||
@@ -176,6 +180,7 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"ServerRDPAllowed": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
@@ -236,6 +241,7 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"allow-server-rdp": "ServerRDPAllowed",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
// registerStates registers all states that need crash recovery cleanup.
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&dns.ShutdownState{})
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
// registerStates registers all states that need crash recovery cleanup.
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&dns.ShutdownState{})
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
|
||||
@@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
|
||||
|
||||
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
hasCommand := session.RawCommand() != ""
|
||||
|
||||
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
|
||||
if err != nil {
|
||||
@@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||
}
|
||||
|
||||
if hasCommand {
|
||||
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
||||
if err := serverSession.Run(session.RawCommand()); err != nil {
|
||||
log.Debugf("run command: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
||||
// when forwarding commands to the backend. This is critical for tools like
|
||||
// Ansible that send commands such as:
|
||||
//
|
||||
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
||||
//
|
||||
// The single quotes must be preserved so the backend shell receives the
|
||||
// subshell expression as a single argument to -c.
|
||||
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
sshClient, cleanup := setupProxySSHClient(t)
|
||||
defer cleanup()
|
||||
|
||||
// These commands simulate what the SSH protocol delivers as exec payloads.
|
||||
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
||||
// the local shell strips the outer single quotes, and the SSH exec request
|
||||
// contains the raw string: /bin/sh -c "( echo hello )"
|
||||
//
|
||||
// The proxy must forward this string verbatim. Using session.Command()
|
||||
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
||||
// the command on the backend.
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "subshell_in_double_quotes",
|
||||
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
||||
expect: "from-subshell\nouter\n",
|
||||
},
|
||||
{
|
||||
name: "printf_with_special_chars",
|
||||
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
||||
expect: "hello world\n",
|
||||
},
|
||||
{
|
||||
name: "nested_command_substitution",
|
||||
command: `/bin/sh -c "echo $(echo nested)"`,
|
||||
expect: "nested\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = session.Close() }()
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
session.Stderr = &stderrBuf
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output(tc.command)
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
if stderrBuf.Len() > 0 {
|
||||
t.Logf("stderr: %s", stderrBuf.String())
|
||||
}
|
||||
require.NoError(t, err, "command should succeed: %s", tc.command)
|
||||
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("command timed out: %s", tc.command)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setupProxySSHClient creates a full proxy test environment and returns
|
||||
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
||||
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
||||
t.Helper()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0},
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
|
||||
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
||||
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
go func() {
|
||||
_ = proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
cleanupFn := func() {
|
||||
_ = client.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
_ = sshServer.Stop()
|
||||
mockDaemon.stop()
|
||||
jwksServer.Close()
|
||||
}
|
||||
|
||||
return client, cleanupFn
|
||||
}
|
||||
|
||||
type mockDaemonServer struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
hostKeys map[string][]byte
|
||||
|
||||
@@ -284,19 +284,21 @@ func (s *Server) closeListener(ln net.Listener) {
|
||||
// Stop closes the SSH server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.sshServer == nil {
|
||||
sshServer := s.sshServer
|
||||
if sshServer == nil {
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.sshServer = nil
|
||||
s.listener = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
if err := s.sshServer.Close(); err != nil {
|
||||
// Close outside the lock: session handlers need s.mu for unregisterSession.
|
||||
if err := sshServer.Close(); err != nil {
|
||||
log.Debugf("close SSH server: %v", err)
|
||||
}
|
||||
|
||||
s.sshServer = nil
|
||||
s.listener = nil
|
||||
|
||||
s.mu.Lock()
|
||||
maps.Clear(s.sessions)
|
||||
maps.Clear(s.pendingAuthJWT)
|
||||
maps.Clear(s.connections)
|
||||
@@ -307,6 +309,7 @@ func (s *Server) Stop() error {
|
||||
}
|
||||
}
|
||||
maps.Clear(s.remoteForwardListeners)
|
||||
s.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
hasCommand := session.RawCommand() != ""
|
||||
|
||||
if isPty && !hasCommand {
|
||||
// ssh <host> - PTY interactive session (login)
|
||||
|
||||
@@ -153,6 +153,9 @@ func networkAddresses() ([]NetworkAddress, error) {
|
||||
|
||||
var netAddresses []NetworkAddress
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
if iface.HardwareAddr.String() == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -43,18 +43,24 @@ func GetInfo(ctx context.Context) *Info {
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
|
||||
addrs, err := networkAddresses()
|
||||
if err != nil {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
return &Info{
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: osInfo[0],
|
||||
Platform: runtime.GOARCH,
|
||||
OS: osName,
|
||||
OSVersion: osVersion,
|
||||
Hostname: extractDeviceName(ctx, systemHostname),
|
||||
CPUs: runtime.NumCPU(),
|
||||
NetbirdVersion: version.NetbirdVersion(),
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
Environment: env,
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: osInfo[0],
|
||||
Platform: runtime.GOARCH,
|
||||
OS: osName,
|
||||
OSVersion: osVersion,
|
||||
Hostname: extractDeviceName(ctx, systemHostname),
|
||||
CPUs: runtime.NumCPU(),
|
||||
NetbirdVersion: version.NetbirdVersion(),
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
NetworkAddresses: addrs,
|
||||
Environment: env,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,9 +24,10 @@ import (
|
||||
|
||||
// Initial state for the debug collection
|
||||
type debugInitialState struct {
|
||||
wasDown bool
|
||||
logLevel proto.LogLevel
|
||||
isLevelTrace bool
|
||||
wasDown bool
|
||||
needsRestoreUp bool
|
||||
logLevel proto.LogLevel
|
||||
isLevelTrace bool
|
||||
}
|
||||
|
||||
// Debug collection parameters
|
||||
@@ -371,46 +372,51 @@ func (s *serviceClient) configureServiceForDebug(
|
||||
conn proto.DaemonServiceClient,
|
||||
state *debugInitialState,
|
||||
enablePersistence bool,
|
||||
) error {
|
||||
) {
|
||||
if state.wasDown {
|
||||
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("bring service up: %v", err)
|
||||
log.Warnf("failed to bring service up: %v", err)
|
||||
} else {
|
||||
log.Info("Service brought up for debug")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
log.Info("Service brought up for debug")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
|
||||
if !state.isLevelTrace {
|
||||
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil {
|
||||
return fmt.Errorf("set log level to TRACE: %v", err)
|
||||
log.Warnf("failed to set log level to TRACE: %v", err)
|
||||
} else {
|
||||
log.Info("Log level set to TRACE for debug")
|
||||
}
|
||||
log.Info("Log level set to TRACE for debug")
|
||||
}
|
||||
|
||||
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("bring service down: %v", err)
|
||||
log.Warnf("failed to bring service down: %v", err)
|
||||
} else {
|
||||
state.needsRestoreUp = !state.wasDown
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
|
||||
if enablePersistence {
|
||||
if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("enable sync response persistence: %v", err)
|
||||
log.Warnf("failed to enable sync response persistence: %v", err)
|
||||
} else {
|
||||
log.Info("Sync response persistence enabled for debug")
|
||||
}
|
||||
log.Info("Sync response persistence enabled for debug")
|
||||
}
|
||||
|
||||
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("bring service back up: %v", err)
|
||||
log.Warnf("failed to bring service back up: %v", err)
|
||||
} else {
|
||||
state.needsRestoreUp = false
|
||||
time.Sleep(time.Second * 3)
|
||||
}
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
|
||||
log.Warnf("failed to start CPU profiling: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) collectDebugData(
|
||||
@@ -424,9 +430,7 @@ func (s *serviceClient) collectDebugData(
|
||||
var wg sync.WaitGroup
|
||||
startProgressTracker(ctx, &wg, params.duration, progress)
|
||||
|
||||
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
|
||||
return err
|
||||
}
|
||||
s.configureServiceForDebug(conn, state, params.enablePersistence)
|
||||
|
||||
wg.Wait()
|
||||
progress.progressBar.Hide()
|
||||
@@ -482,9 +486,17 @@ func (s *serviceClient) createDebugBundleFromCollection(
|
||||
|
||||
// Restore service to original state
|
||||
func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) {
|
||||
if state.needsRestoreUp {
|
||||
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
log.Warnf("failed to restore up state: %v", err)
|
||||
} else {
|
||||
log.Info("Service state restored to up")
|
||||
}
|
||||
}
|
||||
|
||||
if state.wasDown {
|
||||
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
|
||||
log.Errorf("Failed to restore down state: %v", err)
|
||||
log.Warnf("failed to restore down state: %v", err)
|
||||
} else {
|
||||
log.Info("Service state restored to down")
|
||||
}
|
||||
@@ -492,7 +504,7 @@ func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, stat
|
||||
|
||||
if !state.isLevelTrace {
|
||||
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil {
|
||||
log.Errorf("Failed to restore log level: %v", err)
|
||||
log.Warnf("failed to restore log level: %v", err)
|
||||
} else {
|
||||
log.Info("Log level restored to original setting")
|
||||
}
|
||||
|
||||
@@ -179,9 +179,11 @@ type StoreConfig struct {
|
||||
|
||||
// ReverseProxyConfig contains reverse proxy settings
|
||||
type ReverseProxyConfig struct {
|
||||
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
|
||||
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
|
||||
TrustedPeers []string `yaml:"trustedPeers"`
|
||||
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
|
||||
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
|
||||
TrustedPeers []string `yaml:"trustedPeers"`
|
||||
AccessLogRetentionDays int `yaml:"accessLogRetentionDays"`
|
||||
AccessLogCleanupIntervalHours int `yaml:"accessLogCleanupIntervalHours"`
|
||||
}
|
||||
|
||||
// DefaultConfig returns a CombinedConfig with default values
|
||||
@@ -645,7 +647,9 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
|
||||
|
||||
// Build reverse proxy config
|
||||
reverseProxy := nbconfig.ReverseProxy{
|
||||
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
|
||||
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
|
||||
AccessLogRetentionDays: mgmt.ReverseProxy.AccessLogRetentionDays,
|
||||
AccessLogCleanupIntervalHours: mgmt.ReverseProxy.AccessLogCleanupIntervalHours,
|
||||
}
|
||||
for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies {
|
||||
if prefix, err := netip.ParsePrefix(p); err == nil {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user