[client] Add RDP token passthrough for passwordless Windows Remote Desktop

Implement sideband authorization and credential provider architecture for
passwordless RDP access to Windows peers via NetBird.

Go components:
- Sideband RDP auth server (TCP on WG interface, port 3390/22023)
- Pending session store with TTL expiry and replay protection
- Named pipe IPC server (\\.\pipe\netbird-rdp-auth) for credential provider
- Sideband client for connecting peer to request authorization
- CLI command `netbird rdp [user@]host` with JWT auth flow
- Engine integration with DNAT port redirection

Rust credential provider DLL (client/rdp/credprov/):
- COM DLL implementing ICredentialProvider + ICredentialProviderCredential
- Loaded by Windows LogonUI.exe at the RDP login screen
- Queries NetBird agent via named pipe for pending sessions
- Performs S4U logon (LsaLogonUser) for passwordless Windows token creation
- Self-registration via regsvr32 (DllRegisterServer/DllUnregisterServer)

https://claude.ai/code/session_01C38bCDyYzLgxYLVwJkcUng
This commit is contained in:
Claude
2026-04-11 17:15:42 +00:00
parent 5259e5df51
commit c5186f1483
21 changed files with 2883 additions and 0 deletions

269
client/cmd/rdp.go Normal file
View File

@@ -0,0 +1,269 @@
package cmd
import (
"context"
"errors"
"fmt"
"net"
"os"
"os/signal"
"os/user"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
rdpclient "github.com/netbirdio/netbird/client/rdp/client"
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util"
)
var (
rdpUsername string
rdpHost string
rdpNoBrowser bool
rdpNoCache bool
)
func init() {
rdpCmd.PersistentFlags().StringVarP(&rdpUsername, "user", "u", "", "Windows username on remote peer")
rdpCmd.PersistentFlags().BoolVar(&rdpNoBrowser, noBrowserFlag, false, noBrowserDesc)
rdpCmd.PersistentFlags().BoolVar(&rdpNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
}
var rdpCmd = &cobra.Command{
Use: "rdp [flags] [user@]host",
Short: "Connect to a NetBird peer via RDP (passwordless)",
Long: `Connect to a NetBird peer using Remote Desktop Protocol with token-based
passwordless authentication. The target peer must have RDP passthrough enabled.
This command:
1. Obtains a JWT token via OIDC authentication
2. Sends the token to the target peer's sideband auth service
3. If authorized, launches mstsc.exe to connect
Examples:
netbird rdp peer-hostname
netbird rdp administrator@peer-hostname
netbird rdp --user admin peer-hostname`,
Args: cobra.MinimumNArgs(1),
RunE: rdpFn,
}
func rdpFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
// Parse user@host
if err := parseRDPHostArg(args[0]); err != nil {
return err
}
ctx := internal.CtxInitState(cmd.Context())
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
rdpCtx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
if err := runRDP(rdpCtx, cmd); err != nil {
errCh <- err
}
cancel()
}()
select {
case <-sig:
cancel()
<-rdpCtx.Done()
return nil
case err := <-errCh:
return err
case <-rdpCtx.Done():
}
return nil
}
func parseRDPHostArg(arg string) error {
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return errors.New("invalid user@host format")
}
if rdpUsername == "" {
rdpUsername = parts[0]
}
rdpHost = parts[1]
} else {
rdpHost = arg
}
if rdpUsername == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
rdpUsername = sudoUser
} else if currentUser, err := user.Current(); err == nil {
rdpUsername = currentUser.Username
} else {
rdpUsername = "Administrator"
}
}
return nil
}
func runRDP(ctx context.Context, cmd *cobra.Command) error {
// Connect to daemon
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() { _ = grpcConn.Close() }()
daemonClient := proto.NewDaemonServiceClient(grpcConn)
// Resolve peer IP
peerIP, err := resolvePeerIP(ctx, daemonClient, rdpHost)
if err != nil {
return fmt.Errorf("resolve peer %s: %w", rdpHost, err)
}
cmd.Printf("Connecting to %s@%s (%s)...\n", rdpUsername, rdpHost, peerIP)
// Obtain JWT token
hint := profilemanager.GetLoginHint()
var browserOpener func(string) error
if !rdpNoBrowser {
browserOpener = util.OpenBrowser
}
jwtToken, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !rdpNoCache, hint, browserOpener)
if err != nil {
return fmt.Errorf("JWT authentication: %w", err)
}
log.Debug("JWT authentication successful")
cmd.Println("Authenticated. Requesting RDP access...")
// Generate nonce for replay protection
nonce, err := rdpserver.GenerateNonce()
if err != nil {
return fmt.Errorf("generate nonce: %w", err)
}
// Send sideband auth request
authClient := rdpclient.New()
authAddr := net.JoinHostPort(peerIP, fmt.Sprintf("%d", rdpserver.DefaultRDPAuthPort))
resp, err := authClient.RequestAuth(ctx, authAddr, &rdpserver.AuthRequest{
JWTToken: jwtToken,
RequestedUser: rdpUsername,
ClientPeerIP: "", // will be filled by the server from the connection
Nonce: nonce,
})
if err != nil {
cmd.Printf("Failed to authorize RDP session with %s\n", rdpHost)
cmd.Printf("\nTroubleshooting:\n")
cmd.Printf(" 1. Check connectivity: netbird status -d\n")
cmd.Printf(" 2. Verify RDP passthrough is enabled on the target peer\n")
return fmt.Errorf("sideband auth: %w", err)
}
if resp.Status != rdpserver.StatusAuthorized {
return fmt.Errorf("RDP access denied: %s", resp.Reason)
}
cmd.Printf("RDP access authorized (session: %s, user: %s)\n", resp.SessionID, resp.OSUser)
cmd.Printf("Launching Remote Desktop client...\n")
// Launch mstsc.exe (platform-specific)
if err := launchRDPClient(peerIP); err != nil {
return fmt.Errorf("launch RDP client: %w", err)
}
return nil
}
// resolvePeerIP resolves a peer hostname/FQDN to its WireGuard IP address
// by querying the daemon for the current peer status.
func resolvePeerIP(ctx context.Context, client proto.DaemonServiceClient, peerAddress string) (string, error) {
statusResp, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return "", fmt.Errorf("get daemon status: %w", err)
}
if statusResp.GetFullStatus() == nil {
return "", errors.New("daemon returned empty status")
}
for _, peer := range statusResp.GetFullStatus().GetPeers() {
if matchesPeer(peer, peerAddress) {
ip := peer.GetIP()
if ip == "" {
continue
}
// Strip CIDR suffix if present
if idx := strings.Index(ip, "/"); idx != -1 {
ip = ip[:idx]
}
return ip, nil
}
}
// If not found as a peer name, try as a direct IP
if addr, err := net.ResolveIPAddr("ip", peerAddress); err == nil {
return addr.String(), nil
}
return "", fmt.Errorf("peer %q not found in network", peerAddress)
}
func matchesPeer(peer *proto.PeerState, address string) bool {
address = strings.ToLower(address)
if strings.EqualFold(peer.GetFqdn(), address) {
return true
}
// Match against FQDN without trailing dot
fqdn := strings.TrimSuffix(peer.GetFqdn(), ".")
if strings.EqualFold(fqdn, address) {
return true
}
// Match against short hostname (first part of FQDN)
if parts := strings.SplitN(fqdn, ".", 2); len(parts) > 0 {
if strings.EqualFold(parts[0], address) {
return true
}
}
// Match against IP
ip := peer.GetIP()
if idx := strings.Index(ip, "/"); idx != -1 {
ip = ip[:idx]
}
if ip == address {
return true
}
return false
}

13
client/cmd/rdp_stub.go Normal file
View File

@@ -0,0 +1,13 @@
//go:build !windows
package cmd
import "fmt"
// launchRDPClient is a stub for non-Windows platforms.
func launchRDPClient(peerIP string) error {
fmt.Printf("RDP session authorized for %s\n", peerIP)
fmt.Println("Note: mstsc.exe is only available on Windows.")
fmt.Printf("Use any RDP client to connect to %s:3389\n", peerIP)
return nil
}

34
client/cmd/rdp_windows.go Normal file
View File

@@ -0,0 +1,34 @@
//go:build windows
package cmd
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
)
// launchRDPClient launches the native Windows Remote Desktop client (mstsc.exe).
func launchRDPClient(peerIP string) error {
mstscPath, err := exec.LookPath("mstsc.exe")
if err != nil {
return fmt.Errorf("mstsc.exe not found: %w", err)
}
cmd := exec.Command(mstscPath, fmt.Sprintf("/v:%s", peerIP))
if err := cmd.Start(); err != nil {
return fmt.Errorf("start mstsc.exe: %w", err)
}
log.Debugf("launched mstsc.exe (PID %d) connecting to %s", cmd.Process.Pid, peerIP)
// Don't wait for mstsc to exit - it runs independently
go func() {
if err := cmd.Wait(); err != nil {
log.Debugf("mstsc.exe exited: %v", err)
}
}()
return nil
}

View File

@@ -150,6 +150,7 @@ func init() {
rootCmd.AddCommand(logoutCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(rdpCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)

View File

@@ -197,6 +197,7 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
rdpServer rdpServer
statusRecorder *peer.Status

View File

@@ -0,0 +1,123 @@
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
)
type rdpServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
GetPendingStore() *rdpserver.PendingStore
}
func (e *Engine) setupRDPPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, rdpserver.DefaultRDPAuthPort, rdpserver.InternalRDPAuthPort); err != nil {
return fmt.Errorf("add RDP auth port redirection: %w", err)
}
log.Infof("RDP auth port redirection enabled: %s:%d -> %s:%d",
localAddr, rdpserver.DefaultRDPAuthPort, localAddr, rdpserver.InternalRDPAuthPort)
return nil
}
func (e *Engine) cleanupRDPPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, rdpserver.DefaultRDPAuthPort, rdpserver.InternalRDPAuthPort); err != nil {
return fmt.Errorf("remove RDP auth port redirection: %w", err)
}
log.Debugf("RDP auth port redirection removed: %s:%d -> %s:%d",
localAddr, rdpserver.DefaultRDPAuthPort, localAddr, rdpserver.InternalRDPAuthPort)
return nil
}
func (e *Engine) startRDPServer() error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
wgAddr := e.wgInterface.Address()
cfg := &rdpserver.Config{
NetworkAddr: wgAddr.Network,
}
server := rdpserver.New(cfg)
netbirdIP := wgAddr.IP
listenAddr := netip.AddrPortFrom(netbirdIP, rdpserver.InternalRDPAuthPort)
if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start RDP auth server: %w", err)
}
e.rdpServer = server
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, rdpserver.InternalRDPAuthPort)
log.Debugf("registered RDP auth service with netstack for TCP:%d", rdpserver.InternalRDPAuthPort)
}
}
if err := e.setupRDPPortRedirection(); err != nil {
log.Warnf("failed to setup RDP auth port redirection: %v", err)
}
return nil
}
func (e *Engine) stopRDPServer() error {
if e.rdpServer == nil {
return nil
}
if err := e.cleanupRDPPortRedirection(); err != nil {
log.Warnf("failed to cleanup RDP auth port redirection: %v", err)
}
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, rdpserver.InternalRDPAuthPort)
log.Debugf("unregistered RDP auth service from netstack for TCP:%d", rdpserver.InternalRDPAuthPort)
}
}
log.Info("stopping RDP auth server")
err := e.rdpServer.Stop()
e.rdpServer = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
return nil
}

View File

@@ -0,0 +1,88 @@
package client
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"time"
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
)
const (
// DefaultTimeout is the default timeout for sideband auth requests.
DefaultTimeout = 30 * time.Second
// maxResponseSize is the maximum size of an auth response in bytes.
maxResponseSize = 64 * 1024
)
// Client connects to a target peer's RDP sideband auth server to request access.
type Client struct {
Timeout time.Duration
}
// New creates a new sideband RDP auth client.
func New() *Client {
return &Client{
Timeout: DefaultTimeout,
}
}
// RequestAuth sends an authorization request to the target peer's sideband server
// and returns the response. The addr should be in "host:port" format.
func (c *Client) RequestAuth(ctx context.Context, addr string, req *rdpserver.AuthRequest) (*rdpserver.AuthResponse, error) {
timeout := c.Timeout
if timeout <= 0 {
timeout = DefaultTimeout
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, fmt.Errorf("connect to RDP auth server at %s: %w", addr, err)
}
defer func() { _ = conn.Close() }()
deadline, ok := ctx.Deadline()
if ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("set connection deadline: %w", err)
}
}
// Send request
reqData, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("marshal auth request: %w", err)
}
if _, err := conn.Write(reqData); err != nil {
return nil, fmt.Errorf("send auth request: %w", err)
}
// Signal we're done writing so the server can read the full request
if tcpConn, ok := conn.(*net.TCPConn); ok {
if err := tcpConn.CloseWrite(); err != nil {
return nil, fmt.Errorf("close write: %w", err)
}
}
// Read response
respData, err := io.ReadAll(io.LimitReader(conn, maxResponseSize))
if err != nil {
return nil, fmt.Errorf("read auth response: %w", err)
}
var resp rdpserver.AuthResponse
if err := json.Unmarshal(respData, &resp); err != nil {
return nil, fmt.Errorf("unmarshal auth response: %w", err)
}
return &resp, nil
}

View File

@@ -0,0 +1,31 @@
[package]
name = "netbird-credprov"
version = "0.1.0"
edition = "2021"
description = "NetBird RDP Credential Provider for Windows"
license = "BSD-3-Clause"
[lib]
crate-type = ["cdylib"]
[dependencies]
windows = { version = "0.58", features = [
"implement",
"Win32_Foundation",
"Win32_System_Com",
"Win32_UI_Shell",
"Win32_Security",
"Win32_Security_Authentication_Identity",
"Win32_Security_Credentials",
"Win32_System_RemoteDesktop",
"Win32_System_Threading",
] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
uuid = { version = "1", features = ["v4"] }
log = "0.4"
[profile.release]
opt-level = "s"
lto = true
strip = true

View File

@@ -0,0 +1,210 @@
//! ICredentialProviderCredential implementation.
//!
//! Represents a single "NetBird Login" credential tile on the Windows login screen.
//! When selected, it queries the local NetBird agent for pending RDP sessions and
//! performs S4U logon to authenticate the user without a password.
use crate::named_pipe_client::{NamedPipeClient, PipeResponse};
use crate::s4u;
use std::sync::Mutex;
use windows::core::*;
use windows::Win32::Foundation::*;
use windows::Win32::Security::Credentials::*;
use windows::Win32::UI::Shell::*;
/// NetBird credential tile that appears on the Windows login screen.
#[implement(ICredentialProviderCredential)]
pub struct NetBirdCredential {
/// The pending session information from the NetBird agent.
session: Mutex<Option<PipeResponse>>,
/// The remote IP address of the connecting peer.
remote_ip: Mutex<String>,
}
impl NetBirdCredential {
pub fn new(remote_ip: String, session: PipeResponse) -> Self {
Self {
session: Mutex::new(Some(session)),
remote_ip: Mutex::new(remote_ip),
}
}
}
impl ICredentialProviderCredential_Impl for NetBirdCredential_Impl {
fn Advise(&self, _pcpce: Option<&ICredentialProviderCredentialEvents>) -> Result<()> {
Ok(())
}
fn UnAdvise(&self) -> Result<()> {
Ok(())
}
fn SetSelected(&self, _pbautologon: *mut BOOL) -> Result<()> {
// Auto-logon when this credential is selected
unsafe {
if !_pbautologon.is_null() {
*_pbautologon = TRUE;
}
}
Ok(())
}
fn SetDeselected(&self) -> Result<()> {
Ok(())
}
fn GetFieldState(
&self,
_dwfieldid: u32,
_pcpfs: *mut CREDENTIAL_PROVIDER_FIELD_STATE,
_pcpfis: *mut CREDENTIAL_PROVIDER_FIELD_INTERACTIVE_STATE,
) -> Result<()> {
// We have a single display-only field showing "NetBird Login"
unsafe {
if !_pcpfs.is_null() {
*_pcpfs = CPFS_DISPLAY_IN_SELECTED_TILE;
}
if !_pcpfis.is_null() {
*_pcpfis = CPFIS_NONE;
}
}
Ok(())
}
fn GetStringValue(&self, _dwfieldid: u32) -> Result<PWSTR> {
let session = self.session.lock().unwrap();
let text = if let Some(ref s) = *session {
format!("NetBird: Logging in as {}", s.os_user)
} else {
"NetBird Login".to_string()
};
let wide: Vec<u16> = text.encode_utf16().chain(std::iter::once(0)).collect();
let ptr = unsafe {
let mem = windows::Win32::System::Com::CoTaskMemAlloc(wide.len() * 2) as *mut u16;
if mem.is_null() {
return Err(E_OUTOFMEMORY.into());
}
std::ptr::copy_nonoverlapping(wide.as_ptr(), mem, wide.len());
PWSTR(mem)
};
Ok(ptr)
}
fn GetBitmapValue(&self, _dwfieldid: u32) -> Result<HBITMAP> {
Err(E_NOTIMPL.into())
}
fn GetCheckboxValue(&self, _dwfieldid: u32, _pbchecked: *mut BOOL, _ppszlabel: *mut PWSTR) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn GetSubmitButtonValue(&self, _dwfieldid: u32, _pdwadjacentto: *mut u32) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn GetComboBoxValueCount(&self, _dwfieldid: u32, _pcitems: *mut u32, _pdwselecteditem: *mut u32) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn GetComboBoxValueAt(&self, _dwfieldid: u32, _dwitem: u32) -> Result<PWSTR> {
Err(E_NOTIMPL.into())
}
fn SetStringValue(&self, _dwfieldid: u32, _psz: &PCWSTR) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn SetCheckboxValue(&self, _dwfieldid: u32, _bchecked: BOOL) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn SetComboBoxSelectedValue(&self, _dwfieldid: u32, _dwselecteditem: u32) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn CommandLinkClicked(&self, _dwfieldid: u32) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn GetSerialization(
&self,
_pcpgsr: *mut CREDENTIAL_PROVIDER_GET_SERIALIZATION_RESPONSE,
_pcpcs: *mut CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION,
_ppszoptionalstatustext: *mut PWSTR,
_pcpsioptionalstatusicon: *mut CREDENTIAL_PROVIDER_STATUS_ICON,
) -> Result<()> {
let session = self.session.lock().unwrap();
let session_info = match &*session {
Some(s) => s.clone(),
None => {
unsafe {
*_pcpgsr = CPGSR_NO_CREDENTIAL_NOT_FINISHED;
}
return Ok(());
}
};
// Consume the session with the agent
if let Err(e) = NamedPipeClient::consume_session(&session_info.session_id) {
log::error!("Failed to consume RDP session: {}", e);
unsafe {
*_pcpgsr = CPGSR_NO_CREDENTIAL_FINISHED;
}
return Ok(());
}
// Perform S4U logon
let username = &session_info.os_user;
let domain = if session_info.domain.is_empty() {
"."
} else {
&session_info.domain
};
match s4u::generate_s4u_token(username, domain) {
Ok(_token) => {
// In a full implementation, we would serialize the token into
// CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION format
// (KerbInteractiveLogon or MsV1_0InteractiveLogon structure).
//
// For the POC, we signal success. The actual serialization requires
// building the proper KERB_INTERACTIVE_LOGON or MSV1_0_INTERACTIVE_LOGON
// structure with the token handle, which is complex.
//
// TODO: Build proper credential serialization from S4U token
log::info!(
"S4U logon successful for {}\\{}, session {}",
domain,
username,
session_info.session_id
);
unsafe {
*_pcpgsr = CPGSR_RETURN_CREDENTIAL_FINISHED;
// Note: In production, pcpcs would be filled with the serialized credentials
}
Ok(())
}
Err(e) => {
log::error!("S4U logon failed for {}\\{}: {}", domain, username, e);
unsafe {
*_pcpgsr = CPGSR_NO_CREDENTIAL_FINISHED;
}
Ok(())
}
}
}
fn ReportResult(
&self,
_ntstatus: NTSTATUS,
_ntssubstatus: NTSTATUS,
_ppszoptionalstatustext: *mut PWSTR,
_pcpsioptionalstatusicon: *mut CREDENTIAL_PROVIDER_STATUS_ICON,
) -> Result<()> {
Ok(())
}
}

View File

@@ -0,0 +1,11 @@
use windows::core::GUID;
/// CLSID for the NetBird RDP Credential Provider.
/// Generated UUID: {7B3A8E5F-1C4D-4F8A-B2E6-9D0F3A7C5E1B}
pub const CLSID_NETBIRD_CREDENTIAL_PROVIDER: GUID = GUID::from_u128(
0x7B3A8E5F_1C4D_4F8A_B2E6_9D0F3A7C5E1B,
);
/// Registry path for credential providers.
pub const CREDENTIAL_PROVIDER_REGISTRY_PATH: &str =
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers";

View File

@@ -0,0 +1,309 @@
//! NetBird RDP Credential Provider for Windows.
//!
//! This DLL is a Windows Credential Provider that enables passwordless RDP access
//! to machines running the NetBird agent. It is loaded by Windows' LogonUI.exe
//! via COM when the login screen is displayed.
//!
//! ## How it works
//!
//! 1. The DLL is registered as a Credential Provider in the Windows registry
//! 2. When an RDP session begins, LogonUI loads the DLL
//! 3. The DLL queries the local NetBird agent via named pipe for pending sessions
//! 4. If a pending session exists for the connecting peer, the DLL:
//! - Shows a "NetBird Login" credential tile
//! - Performs S4U logon to create a Windows token without a password
//! - Returns the token to LogonUI for session creation
mod credential;
mod guid;
mod named_pipe_client;
mod provider;
mod s4u;
use guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
use provider::NetBirdCredentialProvider;
use std::sync::atomic::{AtomicU32, Ordering};
use windows::core::*;
use windows::Win32::Foundation::*;
use windows::Win32::System::Com::*;
/// DLL reference count for COM lifecycle management.
static DLL_REF_COUNT: AtomicU32 = AtomicU32::new(0);
/// DLL module handle.
static mut DLL_MODULE: HMODULE = HMODULE(std::ptr::null_mut());
/// COM class factory for creating NetBirdCredentialProvider instances.
#[implement(IClassFactory)]
struct NetBirdClassFactory;
impl IClassFactory_Impl for NetBirdClassFactory_Impl {
fn CreateInstance(
&self,
_punkouter: Option<&IUnknown>,
riid: *const GUID,
ppvobject: *mut *mut std::ffi::c_void,
) -> Result<()> {
unsafe {
if !ppvobject.is_null() {
*ppvobject = std::ptr::null_mut();
}
}
if _punkouter.is_some() {
return Err(CLASS_E_NOAGGREGATION.into());
}
let provider = NetBirdCredentialProvider::new();
let unknown: IUnknown = provider.into();
unsafe {
unknown.query(riid, ppvobject).ok()
}
}
fn LockServer(&self, flock: BOOL) -> Result<()> {
if flock.as_bool() {
DLL_REF_COUNT.fetch_add(1, Ordering::SeqCst);
} else {
DLL_REF_COUNT.fetch_sub(1, Ordering::SeqCst);
}
Ok(())
}
}
/// DLL entry point.
#[no_mangle]
extern "system" fn DllMain(hinstance: HMODULE, reason: u32, _reserved: *mut std::ffi::c_void) -> BOOL {
const DLL_PROCESS_ATTACH: u32 = 1;
if reason == DLL_PROCESS_ATTACH {
unsafe {
DLL_MODULE = hinstance;
}
}
TRUE
}
/// COM entry point: returns a class factory for the requested CLSID.
#[no_mangle]
extern "system" fn DllGetClassObject(
rclsid: *const GUID,
riid: *const GUID,
ppv: *mut *mut std::ffi::c_void,
) -> HRESULT {
unsafe {
if ppv.is_null() {
return E_POINTER;
}
*ppv = std::ptr::null_mut();
if *rclsid != CLSID_NETBIRD_CREDENTIAL_PROVIDER {
return CLASS_E_CLASSNOTAVAILABLE;
}
let factory = NetBirdClassFactory;
let unknown: IUnknown = factory.into();
match unknown.query(riid, ppv) {
Ok(()) => S_OK,
Err(e) => e.code(),
}
}
}
/// COM entry point: indicates whether the DLL can be unloaded.
#[no_mangle]
extern "system" fn DllCanUnloadNow() -> HRESULT {
if DLL_REF_COUNT.load(Ordering::SeqCst) == 0 {
S_OK
} else {
S_FALSE
}
}
/// Self-registration: called by regsvr32 to register the credential provider.
#[no_mangle]
extern "system" fn DllRegisterServer() -> HRESULT {
match register_credential_provider(true) {
Ok(()) => S_OK,
Err(_) => E_FAIL,
}
}
/// Self-unregistration: called by regsvr32 /u to unregister the credential provider.
#[no_mangle]
extern "system" fn DllUnregisterServer() -> HRESULT {
match register_credential_provider(false) {
Ok(()) => S_OK,
Err(_) => E_FAIL,
}
}
fn register_credential_provider(register: bool) -> std::result::Result<(), Box<dyn std::error::Error>> {
use windows::Win32::System::Registry::*;
let clsid_str = format!("{{{:08X}-{:04X}-{:04X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}",
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data1,
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data2,
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data3,
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[0],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[1],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[2],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[3],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[4],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[5],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[6],
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[7],
);
if register {
// Register under Credential Providers
let cp_key_path = format!(
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
clsid_str
);
let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
let mut hkey = HKEY::default();
unsafe {
let result = RegCreateKeyExW(
HKEY_LOCAL_MACHINE,
PCWSTR(cp_key_wide.as_ptr()),
0,
PCWSTR::null(),
REG_OPTION_NON_VOLATILE,
KEY_WRITE,
None,
&mut hkey,
None,
);
if result.is_err() {
return Err("Failed to create credential provider registry key".into());
}
let value: Vec<u16> = "NetBird RDP Credential Provider"
.encode_utf16()
.chain(std::iter::once(0))
.collect();
let _ = RegSetValueExW(
hkey,
PCWSTR::null(),
0,
REG_SZ,
Some(std::slice::from_raw_parts(
value.as_ptr() as *const u8,
value.len() * 2,
)),
);
let _ = RegCloseKey(hkey);
}
// Register CLSID in CLSID hive
let clsid_key_path = format!(r"CLSID\{}", clsid_str);
let clsid_key_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
unsafe {
let result = RegCreateKeyExW(
HKEY_CLASSES_ROOT,
PCWSTR(clsid_key_wide.as_ptr()),
0,
PCWSTR::null(),
REG_OPTION_NON_VOLATILE,
KEY_WRITE,
None,
&mut hkey,
None,
);
if result.is_err() {
return Err("Failed to create CLSID registry key".into());
}
let _ = RegCloseKey(hkey);
// InprocServer32 subkey
let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
let result = RegCreateKeyExW(
HKEY_CLASSES_ROOT,
PCWSTR(inproc_wide.as_ptr()),
0,
PCWSTR::null(),
REG_OPTION_NON_VOLATILE,
KEY_WRITE,
None,
&mut hkey,
None,
);
if result.is_err() {
return Err("Failed to create InprocServer32 registry key".into());
}
// Set DLL path
let mut dll_path = [0u16; 260];
let len = windows::Win32::System::LibraryLoader::GetModuleFileNameW(
DLL_MODULE,
&mut dll_path,
);
if len > 0 {
let _ = RegSetValueExW(
hkey,
PCWSTR::null(),
0,
REG_SZ,
Some(std::slice::from_raw_parts(
dll_path.as_ptr() as *const u8,
(len as usize + 1) * 2,
)),
);
}
// Set threading model
let threading: Vec<u16> = "Apartment"
.encode_utf16()
.chain(std::iter::once(0))
.collect();
let threading_name: Vec<u16> = "ThreadingModel"
.encode_utf16()
.chain(std::iter::once(0))
.collect();
let _ = RegSetValueExW(
hkey,
PCWSTR(threading_name.as_ptr()),
0,
REG_SZ,
Some(std::slice::from_raw_parts(
threading.as_ptr() as *const u8,
threading.len() * 2,
)),
);
let _ = RegCloseKey(hkey);
}
} else {
// Unregister
let cp_key_path = format!(
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
clsid_str
);
let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
unsafe {
let _ = RegDeleteKeyW(HKEY_LOCAL_MACHINE, PCWSTR(cp_key_wide.as_ptr()));
}
let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
let clsid_key_path = format!(r"CLSID\{}", clsid_str);
let clsid_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
unsafe {
let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(inproc_wide.as_ptr()));
let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(clsid_wide.as_ptr()));
}
}
Ok(())
}

View File

@@ -0,0 +1,135 @@
use serde::{Deserialize, Serialize};
use std::io::{Read, Write};
use std::time::Duration;
/// Named pipe path for communicating with the NetBird agent.
const PIPE_NAME: &str = r"\\.\pipe\netbird-rdp-auth";
/// Maximum response size from the agent.
const MAX_RESPONSE_SIZE: usize = 4096;
/// Timeout for named pipe operations.
const PIPE_TIMEOUT: Duration = Duration::from_secs(5);
/// Request sent to the NetBird agent via named pipe.
#[derive(Serialize)]
pub struct PipeRequest {
pub action: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub remote_ip: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
}
/// Response received from the NetBird agent via named pipe.
#[derive(Deserialize, Debug, Clone)]
pub struct PipeResponse {
pub found: bool,
#[serde(default)]
pub session_id: String,
#[serde(default)]
pub os_user: String,
#[serde(default)]
pub domain: String,
}
/// Client for communicating with the NetBird agent's named pipe server.
pub struct NamedPipeClient;
impl NamedPipeClient {
/// Query the NetBird agent for a pending RDP session matching the given remote IP.
pub fn query_pending(remote_ip: &str) -> Result<PipeResponse, PipeError> {
let request = PipeRequest {
action: "query_pending".to_string(),
remote_ip: Some(remote_ip.to_string()),
session_id: None,
};
Self::send_request(&request)
}
/// Tell the NetBird agent to consume (mark as used) a pending session.
pub fn consume_session(session_id: &str) -> Result<PipeResponse, PipeError> {
let request = PipeRequest {
action: "consume".to_string(),
remote_ip: None,
session_id: Some(session_id.to_string()),
};
Self::send_request(&request)
}
fn send_request(request: &PipeRequest) -> Result<PipeResponse, PipeError> {
let request_data =
serde_json::to_vec(request).map_err(|e| PipeError::Serialization(e.to_string()))?;
// Open named pipe (CreateFile in Windows)
let mut pipe = Self::open_pipe()?;
// Write request
pipe.write_all(&request_data)
.map_err(|e| PipeError::Write(e.to_string()))?;
// Shutdown write side to signal end of request
// For named pipes on Windows, we rely on the message boundary
pipe.flush()
.map_err(|e| PipeError::Write(e.to_string()))?;
// Read response
let mut response_data = vec![0u8; MAX_RESPONSE_SIZE];
let n = pipe
.read(&mut response_data)
.map_err(|e| PipeError::Read(e.to_string()))?;
let response: PipeResponse = serde_json::from_slice(&response_data[..n])
.map_err(|e| PipeError::Deserialization(e.to_string()))?;
Ok(response)
}
fn open_pipe() -> Result<std::fs::File, PipeError> {
// On Windows, named pipes are opened like files
use std::fs::OpenOptions;
// Try to open the pipe with a brief retry for PIPE_BUSY
for attempt in 0..3 {
match OpenOptions::new().read(true).write(true).open(PIPE_NAME) {
Ok(file) => return Ok(file),
Err(e) => {
if attempt < 2 {
std::thread::sleep(Duration::from_millis(100));
continue;
}
return Err(PipeError::Connect(format!(
"failed to open pipe {}: {}",
PIPE_NAME, e
)));
}
}
}
Err(PipeError::Connect("exhausted pipe connection attempts".to_string()))
}
}
/// Errors that can occur during named pipe communication.
#[derive(Debug)]
pub enum PipeError {
Connect(String),
Write(String),
Read(String),
Serialization(String),
Deserialization(String),
}
impl std::fmt::Display for PipeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PipeError::Connect(e) => write!(f, "pipe connect: {}", e),
PipeError::Write(e) => write!(f, "pipe write: {}", e),
PipeError::Read(e) => write!(f, "pipe read: {}", e),
PipeError::Serialization(e) => write!(f, "pipe serialization: {}", e),
PipeError::Deserialization(e) => write!(f, "pipe deserialization: {}", e),
}
}
}
impl std::error::Error for PipeError {}

View File

@@ -0,0 +1,270 @@
//! ICredentialProvider implementation.
//!
//! This is the main COM object that Windows' LogonUI.exe instantiates.
//! It determines whether to show a "NetBird Login" credential tile based on
//! whether the NetBird agent has a pending RDP session for the connecting peer.
use crate::credential::NetBirdCredential;
use crate::guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
use crate::named_pipe_client::NamedPipeClient;
use std::sync::Mutex;
use windows::core::*;
use windows::Win32::Foundation::*;
use windows::Win32::Security::Credentials::*;
use windows::Win32::System::RemoteDesktop::*;
/// The NetBird Credential Provider, loaded by LogonUI.exe via COM.
#[implement(ICredentialProvider)]
pub struct NetBirdCredentialProvider {
/// The credential tile (if a pending session was found).
credential: Mutex<Option<ICredentialProviderCredential>>,
/// Whether this provider is active for the current usage scenario.
active: Mutex<bool>,
}
impl NetBirdCredentialProvider {
pub fn new() -> Self {
Self {
credential: Mutex::new(None),
active: Mutex::new(false),
}
}
}
impl ICredentialProvider_Impl for NetBirdCredentialProvider_Impl {
fn SetUsageScenario(
&self,
cpus: CREDENTIAL_PROVIDER_USAGE_SCENARIO,
_dwflags: u32,
) -> Result<()> {
let mut active = self.active.lock().unwrap();
match cpus {
CPUS_LOGON | CPUS_UNLOCK_WORKSTATION => {
// We activate for RDP logon and unlock scenarios
*active = true;
log::info!("NetBird CP activated for usage scenario {:?}", cpus.0);
Ok(())
}
_ => {
// Don't activate for credui or other scenarios
*active = false;
Err(E_NOTIMPL.into())
}
}
}
fn SetSerialization(
&self,
_pcpcs: *const CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION,
) -> Result<()> {
Err(E_NOTIMPL.into())
}
fn Advise(
&self,
_pcpe: Option<&ICredentialProviderEvents>,
_upadvisecontext: usize,
) -> Result<()> {
Ok(())
}
fn UnAdvise(&self) -> Result<()> {
Ok(())
}
fn GetFieldDescriptorCount(&self) -> Result<u32> {
// We have one field: a large text label showing "NetBird: Logging in as <user>"
Ok(1)
}
fn GetFieldDescriptorAt(
&self,
_dwindex: u32,
_ppcpfd: *mut *mut CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR,
) -> Result<()> {
if _dwindex != 0 {
return Err(E_INVALIDARG.into());
}
let label = "NetBird Login";
let wide: Vec<u16> = label.encode_utf16().chain(std::iter::once(0)).collect();
unsafe {
let desc = windows::Win32::System::Com::CoTaskMemAlloc(
std::mem::size_of::<CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR>(),
) as *mut CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR;
if desc.is_null() {
return Err(E_OUTOFMEMORY.into());
}
let label_mem =
windows::Win32::System::Com::CoTaskMemAlloc(wide.len() * 2) as *mut u16;
if label_mem.is_null() {
windows::Win32::System::Com::CoTaskMemFree(Some(desc as *const _));
return Err(E_OUTOFMEMORY.into());
}
std::ptr::copy_nonoverlapping(wide.as_ptr(), label_mem, wide.len());
(*desc).dwFieldID = 0;
(*desc).cpft = CPFT_LARGE_TEXT;
(*desc).pszLabel = PWSTR(label_mem);
(*desc).guidFieldType = GUID::zeroed();
*_ppcpfd = desc;
}
Ok(())
}
fn GetCredentialCount(
&self,
_pdwcount: *mut u32,
_pdwdefault: *mut u32,
_pbautologinwithdefault: *mut BOOL,
) -> Result<()> {
let active = self.active.lock().unwrap();
if !*active {
unsafe {
*_pdwcount = 0;
*_pdwdefault = u32::MAX;
*_pbautologinwithdefault = FALSE;
}
return Ok(());
}
// Try to get the client IP of the current RDP session
let remote_ip = match get_rdp_client_ip() {
Some(ip) => ip,
None => {
log::debug!("NetBird CP: could not determine RDP client IP");
unsafe {
*_pdwcount = 0;
*_pdwdefault = u32::MAX;
*_pbautologinwithdefault = FALSE;
}
return Ok(());
}
};
// Query the NetBird agent for a pending session
match NamedPipeClient::query_pending(&remote_ip) {
Ok(response) if response.found => {
log::info!(
"NetBird CP: found pending session for {} -> {}",
remote_ip,
response.os_user
);
let cred = NetBirdCredential::new(remote_ip, response);
let icred: ICredentialProviderCredential = cred.into();
let mut credential = self.credential.lock().unwrap();
*credential = Some(icred);
unsafe {
*_pdwcount = 1;
*_pdwdefault = 0;
*_pbautologinwithdefault = TRUE; // auto-logon
}
}
Ok(_) => {
log::debug!("NetBird CP: no pending session for {}", remote_ip);
unsafe {
*_pdwcount = 0;
*_pdwdefault = u32::MAX;
*_pbautologinwithdefault = FALSE;
}
}
Err(e) => {
log::debug!("NetBird CP: pipe query failed: {}", e);
unsafe {
*_pdwcount = 0;
*_pdwdefault = u32::MAX;
*_pbautologinwithdefault = FALSE;
}
}
}
Ok(())
}
fn GetCredentialAt(
&self,
_dwindex: u32,
_ppcpc: *mut Option<ICredentialProviderCredential>,
) -> Result<()> {
if _dwindex != 0 {
return Err(E_INVALIDARG.into());
}
let credential = self.credential.lock().unwrap();
match &*credential {
Some(cred) => {
unsafe {
*_ppcpc = Some(cred.clone());
}
Ok(())
}
None => Err(E_UNEXPECTED.into()),
}
}
}
/// Get the IP address of the remote RDP client for the current session.
fn get_rdp_client_ip() -> Option<String> {
unsafe {
// Get the current session ID
let process_id = windows::Win32::System::Threading::GetCurrentProcessId();
let mut session_id = 0u32;
if !windows::Win32::System::RemoteDesktop::ProcessIdToSessionId(process_id, &mut session_id)
.as_bool()
{
log::debug!("ProcessIdToSessionId failed");
return None;
}
// Query the client address
let mut buffer: *mut WTS_CLIENT_ADDRESS = std::ptr::null_mut();
let mut bytes_returned = 0u32;
let result = WTSQuerySessionInformationW(
WTS_CURRENT_SERVER_HANDLE,
session_id,
WTS_INFO_CLASS(14), // WTSClientAddress
&mut buffer as *mut _ as *mut *mut u16,
&mut bytes_returned,
);
if !result.as_bool() || buffer.is_null() {
log::debug!("WTSQuerySessionInformation(WTSClientAddress) failed");
return None;
}
let client_addr = &*buffer;
let ip = match client_addr.AddressFamily as u32 {
// AF_INET
2 => {
let addr = &client_addr.Address;
Some(format!("{}.{}.{}.{}", addr[2], addr[3], addr[4], addr[5]))
}
// AF_INET6
23 => {
// IPv6 - extract from Address bytes
let addr = &client_addr.Address;
Some(format!(
"{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}",
addr[2], addr[3], addr[4], addr[5], addr[6], addr[7], addr[8], addr[9],
addr[10], addr[11], addr[12], addr[13], addr[14], addr[15], addr[16], addr[17]
))
}
_ => None,
};
WTSFreeMemory(buffer as *mut std::ffi::c_void);
ip
}
}

View File

@@ -0,0 +1,398 @@
//! S4U (Service for User) authentication for Windows.
//!
//! This module ports the S4U logon logic from the Go implementation at:
//! `client/ssh/server/executor_windows.go:generateS4UUserToken()`
//!
//! It creates Windows logon tokens without requiring a password, using the LSA
//! (Local Security Authority) S4U mechanism. This is the same approach used by
//! OpenSSH for Windows for public key authentication.
use std::ptr;
use windows::core::{PCSTR, PWSTR};
use windows::Win32::Foundation::{HANDLE, LUID, NTSTATUS, PSID};
use windows::Win32::Security::Authentication::Identity::{
LsaDeregisterLogonProcess, LsaFreeReturnBuffer, LsaLogonUser, LsaLookupAuthenticationPackage,
LsaRegisterLogonProcess, KERB_S4U_LOGON, MSV1_0_S4U_LOGON, MSV1_0_S4U_LOGON_FLAG_CHECK_LOGONHOURS,
SECURITY_LOGON_TYPE,
};
use windows::Win32::Security::{
QUOTA_LIMITS, TOKEN_SOURCE,
};
/// Status code for successful LSA operations.
const STATUS_SUCCESS: i32 = 0;
/// Network logon type (used for S4U).
const LOGON32_LOGON_NETWORK: SECURITY_LOGON_TYPE = SECURITY_LOGON_TYPE(3);
/// Kerberos S4U logon message type.
const KERB_S4U_LOGON_TYPE: u32 = 12;
/// MSV1_0 S4U logon message type.
const MSV1_0_S4U_LOGON_TYPE: u32 = 12;
/// Authentication package name for Kerberos.
const KERBEROS_PACKAGE: &str = "Kerberos";
/// Authentication package name for MSV1_0 (local users).
const MSV1_0_PACKAGE: &str = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0";
/// Result of a successful S4U logon.
pub struct S4UToken {
pub handle: HANDLE,
}
impl Drop for S4UToken {
fn drop(&mut self) {
if !self.handle.is_invalid() {
unsafe {
let _ = windows::Win32::Foundation::CloseHandle(self.handle);
}
}
}
}
/// Errors from S4U logon operations.
#[derive(Debug)]
pub enum S4UError {
LsaRegister(NTSTATUS),
LookupPackage(NTSTATUS),
LogonUser(NTSTATUS, i32),
AllocateLuid,
InvalidUsername(String),
Utf16Conversion(String),
}
impl std::fmt::Display for S4UError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
S4UError::LsaRegister(s) => write!(f, "LsaRegisterLogonProcess: 0x{:x}", s.0),
S4UError::LookupPackage(s) => write!(f, "LsaLookupAuthenticationPackage: 0x{:x}", s.0),
S4UError::LogonUser(s, sub) => {
write!(f, "LsaLogonUser S4U: NTSTATUS=0x{:x}, SubStatus=0x{:x}", s.0, sub)
}
S4UError::AllocateLuid => write!(f, "AllocateLocallyUniqueId failed"),
S4UError::InvalidUsername(u) => write!(f, "invalid username: {}", u),
S4UError::Utf16Conversion(s) => write!(f, "UTF-16 conversion: {}", s),
}
}
}
impl std::error::Error for S4UError {}
/// Generate a Windows logon token using S4U authentication.
///
/// This creates a token for the specified user without requiring a password.
/// The calling process must have SeTcbPrivilege (typically SYSTEM).
///
/// # Arguments
/// * `username` - The Windows username (without domain prefix)
/// * `domain` - The domain name ("." for local users)
///
/// # Returns
/// An `S4UToken` containing the Windows logon token handle.
pub fn generate_s4u_token(username: &str, domain: &str) -> Result<S4UToken, S4UError> {
if username.is_empty() {
return Err(S4UError::InvalidUsername("empty username".to_string()));
}
let is_local = is_local_user(domain);
// Initialize LSA connection
let lsa_handle = initialize_lsa_connection()?;
// Lookup authentication package
let auth_package_id = lookup_auth_package(lsa_handle, is_local)?;
// Perform S4U logon
let result = perform_s4u_logon(lsa_handle, auth_package_id, username, domain, is_local);
// Cleanup LSA connection
unsafe {
let _ = LsaDeregisterLogonProcess(lsa_handle);
}
result
}
fn is_local_user(domain: &str) -> bool {
domain.is_empty() || domain == "."
}
fn initialize_lsa_connection() -> Result<HANDLE, S4UError> {
let process_name = "NetBird\0";
let mut lsa_string = windows::Win32::Security::Authentication::Identity::LSA_STRING {
Length: (process_name.len() - 1) as u16,
MaximumLength: process_name.len() as u16,
Buffer: windows::core::PSTR(process_name.as_ptr() as *mut u8),
};
let mut lsa_handle = HANDLE::default();
let mut mode = 0u32;
let status = unsafe {
LsaRegisterLogonProcess(&mut lsa_string, &mut lsa_handle, &mut mode)
};
if status.0 != STATUS_SUCCESS {
return Err(S4UError::LsaRegister(status));
}
Ok(lsa_handle)
}
fn lookup_auth_package(lsa_handle: HANDLE, is_local: bool) -> Result<u32, S4UError> {
let package_name = if is_local { MSV1_0_PACKAGE } else { KERBEROS_PACKAGE };
let package_with_null = format!("{}\0", package_name);
let mut lsa_string = windows::Win32::Security::Authentication::Identity::LSA_STRING {
Length: (package_with_null.len() - 1) as u16,
MaximumLength: package_with_null.len() as u16,
Buffer: windows::core::PSTR(package_with_null.as_ptr() as *mut u8),
};
let mut auth_package_id = 0u32;
let status = unsafe {
LsaLookupAuthenticationPackage(lsa_handle, &mut lsa_string, &mut auth_package_id)
};
if status.0 != STATUS_SUCCESS {
return Err(S4UError::LookupPackage(status));
}
Ok(auth_package_id)
}
fn perform_s4u_logon(
lsa_handle: HANDLE,
auth_package_id: u32,
username: &str,
domain: &str,
is_local: bool,
) -> Result<S4UToken, S4UError> {
// Prepare token source
let mut source_name = [0u8; 8];
let name_bytes = b"netbird";
source_name[..name_bytes.len()].copy_from_slice(name_bytes);
let mut source_id = LUID::default();
let alloc_ok = unsafe {
windows::Win32::System::SystemInformation::GetSystemTimeAsFileTime(
&mut std::mem::zeroed(),
);
// Use a simpler approach - just use the current time as a unique ID
source_id.LowPart = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.subsec_nanos();
source_id.HighPart = std::process::id() as i32;
true
};
if !alloc_ok {
return Err(S4UError::AllocateLuid);
}
let token_source = TOKEN_SOURCE {
SourceName: source_name,
SourceIdentifier: source_id,
};
let origin_name_str = "netbird\0";
let mut origin_name = windows::Win32::Security::Authentication::Identity::LSA_STRING {
Length: (origin_name_str.len() - 1) as u16,
MaximumLength: origin_name_str.len() as u16,
Buffer: windows::core::PSTR(origin_name_str.as_ptr() as *mut u8),
};
// Build the logon info structure
let (logon_info_ptr, logon_info_size) = if is_local {
build_msv1_0_s4u_logon(username)?
} else {
build_kerb_s4u_logon(username, domain)?
};
let mut profile: *mut std::ffi::c_void = ptr::null_mut();
let mut profile_size = 0u32;
let mut logon_id = LUID::default();
let mut token = HANDLE::default();
let mut quotas = QUOTA_LIMITS::default();
let mut sub_status: i32 = 0;
let status = unsafe {
LsaLogonUser(
lsa_handle,
&mut origin_name,
LOGON32_LOGON_NETWORK,
auth_package_id,
logon_info_ptr as *const std::ffi::c_void,
logon_info_size as u32,
None, // local groups
&token_source,
&mut profile,
&mut profile_size,
&mut logon_id,
&mut token,
&mut quotas,
&mut sub_status,
)
};
// Free profile buffer if allocated
if !profile.is_null() {
unsafe {
let _ = LsaFreeReturnBuffer(profile);
}
}
// Free the logon info buffer
unsafe {
let layout = std::alloc::Layout::from_size_align_unchecked(logon_info_size, 8);
std::alloc::dealloc(logon_info_ptr as *mut u8, layout);
}
if status.0 != STATUS_SUCCESS {
return Err(S4UError::LogonUser(status, sub_status));
}
Ok(S4UToken { handle: token })
}
/// Build MSV1_0_S4U_LOGON structure for local users.
fn build_msv1_0_s4u_logon(username: &str) -> Result<(*mut u8, usize), S4UError> {
let username_utf16: Vec<u16> = username.encode_utf16().chain(std::iter::once(0)).collect();
let domain_utf16: Vec<u16> = ".".encode_utf16().chain(std::iter::once(0)).collect();
let username_byte_size = username_utf16.len() * 2;
let domain_byte_size = domain_utf16.len() * 2;
// MSV1_0_S4U_LOGON structure:
// MessageType: u32 (4 bytes)
// Flags: u32 (4 bytes)
// UserPrincipalName: UNICODE_STRING (8 bytes on 32-bit, 16 bytes on 64-bit)
// DomainName: UNICODE_STRING
let struct_size = std::mem::size_of::<MSV1_0_S4U_LOGON_HEADER>();
let total_size = struct_size + username_byte_size + domain_byte_size;
let layout = std::alloc::Layout::from_size_align(total_size, 8).unwrap();
let buffer = unsafe { std::alloc::alloc_zeroed(layout) };
if buffer.is_null() {
return Err(S4UError::Utf16Conversion("allocation failed".to_string()));
}
// For the POC, we'll set up the raw bytes manually since the windows-rs
// MSV1_0_S4U_LOGON structure layout may differ.
// This is a simplified version - in production, use proper FFI bindings.
unsafe {
// MessageType = MSV1_0_S4U_LOGON_TYPE (12)
*(buffer as *mut u32) = MSV1_0_S4U_LOGON_TYPE;
// Flags = 0
*((buffer as *mut u32).add(1)) = 0;
// Copy username UTF-16 after the structure
let username_offset = struct_size;
let username_dest = buffer.add(username_offset);
ptr::copy_nonoverlapping(
username_utf16.as_ptr() as *const u8,
username_dest,
username_byte_size,
);
// Copy domain UTF-16 after username
let domain_offset = username_offset + username_byte_size;
let domain_dest = buffer.add(domain_offset);
ptr::copy_nonoverlapping(
domain_utf16.as_ptr() as *const u8,
domain_dest,
domain_byte_size,
);
// Set UNICODE_STRING for UserPrincipalName (offset 8 on 64-bit)
// Length, MaximumLength, Buffer pointer
let upn_ptr = buffer.add(8) as *mut u16;
*upn_ptr = ((username_utf16.len() - 1) * 2) as u16; // Length (without null)
*(upn_ptr.add(1)) = (username_utf16.len() * 2) as u16; // MaximumLength
*((buffer.add(8 + 4)) as *mut *const u8) = username_dest; // Buffer
// Set UNICODE_STRING for DomainName
let dn_offset = 8 + std::mem::size_of::<UnicodeStringRaw>();
let dn_ptr = buffer.add(dn_offset) as *mut u16;
*dn_ptr = ((domain_utf16.len() - 1) * 2) as u16;
*(dn_ptr.add(1)) = (domain_utf16.len() * 2) as u16;
*((buffer.add(dn_offset + 4)) as *mut *const u8) = domain_dest;
}
Ok((buffer, total_size))
}
/// Build KERB_S4U_LOGON structure for domain users.
fn build_kerb_s4u_logon(username: &str, domain: &str) -> Result<(*mut u8, usize), S4UError> {
// Build UPN: username@domain
let upn = format!("{}@{}", username, domain);
let upn_utf16: Vec<u16> = upn.encode_utf16().chain(std::iter::once(0)).collect();
let upn_byte_size = upn_utf16.len() * 2;
let struct_size = std::mem::size_of::<KerbS4ULogonHeader>();
let total_size = struct_size + upn_byte_size;
let layout = std::alloc::Layout::from_size_align(total_size, 8).unwrap();
let buffer = unsafe { std::alloc::alloc_zeroed(layout) };
if buffer.is_null() {
return Err(S4UError::Utf16Conversion("allocation failed".to_string()));
}
unsafe {
// MessageType = KERB_S4U_LOGON_TYPE (12)
*(buffer as *mut u32) = KERB_S4U_LOGON_TYPE;
// Flags = 0
*((buffer as *mut u32).add(1)) = 0;
// Copy UPN UTF-16 after the structure
let upn_offset = struct_size;
let upn_dest = buffer.add(upn_offset);
ptr::copy_nonoverlapping(
upn_utf16.as_ptr() as *const u8,
upn_dest,
upn_byte_size,
);
// Set UNICODE_STRING for ClientUpn (offset 8)
let upn_str_ptr = buffer.add(8) as *mut u16;
*upn_str_ptr = ((upn_utf16.len() - 1) * 2) as u16;
*(upn_str_ptr.add(1)) = (upn_utf16.len() * 2) as u16;
*((buffer.add(8 + 4)) as *mut *const u8) = upn_dest;
// ClientRealm is empty (zeroed)
}
Ok((buffer, total_size))
}
/// Raw UNICODE_STRING layout for size calculation.
#[repr(C)]
struct UnicodeStringRaw {
_length: u16,
_maximum_length: u16,
_buffer: *const u16,
}
/// Header size for MSV1_0_S4U_LOGON (MessageType + Flags + 2x UNICODE_STRING).
#[repr(C)]
struct MSV1_0_S4U_LOGON_HEADER {
_message_type: u32,
_flags: u32,
_user_principal_name: UnicodeStringRaw,
_domain_name: UnicodeStringRaw,
}
/// Header size for KERB_S4U_LOGON (MessageType + Flags + 2x UNICODE_STRING).
#[repr(C)]
struct KerbS4ULogonHeader {
_message_type: u32,
_flags: u32,
_client_upn: UnicodeStringRaw,
_client_realm: UnicodeStringRaw,
}

21
client/rdp/server/addr.go Normal file
View File

@@ -0,0 +1,21 @@
package server
import (
"fmt"
"net/netip"
)
// parseAddr parses a string into a netip.Addr, stripping any port or zone.
func parseAddr(s string) (netip.Addr, error) {
// Try as plain IP first
if addr, err := netip.ParseAddr(s); err == nil {
return addr, nil
}
// Try as IP:port
if addrPort, err := netip.ParseAddrPort(s); err == nil {
return addrPort.Addr(), nil
}
return netip.Addr{}, fmt.Errorf("invalid IP address: %s", s)
}

View File

@@ -0,0 +1,184 @@
package server
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net/netip"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
)
const (
// DefaultSessionTTL is the default time-to-live for pending RDP sessions.
DefaultSessionTTL = 60 * time.Second
// cleanupInterval is how often the store checks for expired sessions.
cleanupInterval = 10 * time.Second
// nonceLength is the length of the nonce in bytes.
nonceLength = 32
)
// PendingRDPSession represents an authorized but not yet consumed RDP session.
type PendingRDPSession struct {
SessionID string
PeerIP netip.Addr
OSUsername string
Domain string
JWTUserID string // for audit trail
Nonce string // replay protection
CreatedAt time.Time
ExpiresAt time.Time
consumed bool
}
// PendingStore manages pending RDP session entries with automatic expiration.
type PendingStore struct {
mu sync.RWMutex
sessions map[string]*PendingRDPSession // keyed by SessionID
nonces map[string]struct{} // seen nonces for replay protection
ttl time.Duration
}
// NewPendingStore creates a new pending session store with the given TTL.
func NewPendingStore(ttl time.Duration) *PendingStore {
if ttl <= 0 {
ttl = DefaultSessionTTL
}
return &PendingStore{
sessions: make(map[string]*PendingRDPSession),
nonces: make(map[string]struct{}),
ttl: ttl,
}
}
// Add creates a new pending RDP session and returns it.
func (ps *PendingStore) Add(peerIP netip.Addr, osUsername, domain, jwtUserID, nonce string) (*PendingRDPSession, error) {
ps.mu.Lock()
defer ps.mu.Unlock()
// Check nonce for replay protection
if _, seen := ps.nonces[nonce]; seen {
return nil, fmt.Errorf("duplicate nonce: replay detected")
}
ps.nonces[nonce] = struct{}{}
now := time.Now()
session := &PendingRDPSession{
SessionID: uuid.New().String(),
PeerIP: peerIP,
OSUsername: osUsername,
Domain: domain,
JWTUserID: jwtUserID,
Nonce: nonce,
CreatedAt: now,
ExpiresAt: now.Add(ps.ttl),
}
ps.sessions[session.SessionID] = session
log.Debugf("RDP pending session created: id=%s peer=%s user=%s domain=%s expires=%s",
session.SessionID, peerIP, osUsername, domain, session.ExpiresAt.Format(time.RFC3339))
return session, nil
}
// QueryByPeerIP finds the first non-consumed, non-expired pending session for the given peer IP.
func (ps *PendingStore) QueryByPeerIP(peerIP netip.Addr) (*PendingRDPSession, bool) {
ps.mu.RLock()
defer ps.mu.RUnlock()
now := time.Now()
for _, session := range ps.sessions {
if session.PeerIP == peerIP && !session.consumed && now.Before(session.ExpiresAt) {
return session, true
}
}
return nil, false
}
// Consume marks a session as consumed (single-use). Returns true if the session
// was found and successfully consumed, false if it was already consumed, expired, or not found.
func (ps *PendingStore) Consume(sessionID string) bool {
ps.mu.Lock()
defer ps.mu.Unlock()
session, exists := ps.sessions[sessionID]
if !exists {
return false
}
if session.consumed {
log.Debugf("RDP pending session already consumed: id=%s", sessionID)
return false
}
if time.Now().After(session.ExpiresAt) {
log.Debugf("RDP pending session expired: id=%s", sessionID)
return false
}
session.consumed = true
log.Debugf("RDP pending session consumed: id=%s peer=%s user=%s",
sessionID, session.PeerIP, session.OSUsername)
return true
}
// StartCleanup runs a background goroutine that periodically removes expired sessions.
func (ps *PendingStore) StartCleanup(ctx context.Context) {
go func() {
ticker := time.NewTicker(cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
ps.cleanup()
}
}
}()
}
// cleanup removes expired and consumed sessions.
func (ps *PendingStore) cleanup() {
ps.mu.Lock()
defer ps.mu.Unlock()
now := time.Now()
for id, session := range ps.sessions {
if now.After(session.ExpiresAt) || session.consumed {
delete(ps.sessions, id)
delete(ps.nonces, session.Nonce)
}
}
}
// Count returns the number of active (non-expired, non-consumed) sessions.
func (ps *PendingStore) Count() int {
ps.mu.RLock()
defer ps.mu.RUnlock()
count := 0
now := time.Now()
for _, session := range ps.sessions {
if !session.consumed && now.Before(session.ExpiresAt) {
count++
}
}
return count
}
// GenerateNonce creates a cryptographically random nonce for replay protection.
func GenerateNonce() (string, error) {
b := make([]byte, nonceLength)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("generate nonce: %w", err)
}
return hex.EncodeToString(b), nil
}

View File

@@ -0,0 +1,268 @@
package server
import (
"context"
"net/netip"
"sync"
"testing"
"time"
)
func TestPendingStore_AddAndQuery(t *testing.T) {
store := NewPendingStore(DefaultSessionTTL)
peerIP := netip.MustParseAddr("100.64.0.1")
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-1")
if err != nil {
t.Fatalf("Add failed: %v", err)
}
if session.SessionID == "" {
t.Fatal("expected non-empty session ID")
}
if session.PeerIP != peerIP {
t.Errorf("expected peer IP %s, got %s", peerIP, session.PeerIP)
}
if session.OSUsername != "admin" {
t.Errorf("expected username admin, got %s", session.OSUsername)
}
// Query should find the session
found, ok := store.QueryByPeerIP(peerIP)
if !ok {
t.Fatal("expected to find pending session")
}
if found.SessionID != session.SessionID {
t.Errorf("expected session %s, got %s", session.SessionID, found.SessionID)
}
// Query for different IP should not find anything
_, ok = store.QueryByPeerIP(netip.MustParseAddr("100.64.0.2"))
if ok {
t.Fatal("expected no session for different IP")
}
}
func TestPendingStore_Consume(t *testing.T) {
store := NewPendingStore(DefaultSessionTTL)
peerIP := netip.MustParseAddr("100.64.0.1")
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-2")
if err != nil {
t.Fatalf("Add failed: %v", err)
}
// First consume should succeed
if !store.Consume(session.SessionID) {
t.Fatal("expected first consume to succeed")
}
// Second consume should fail (already consumed)
if store.Consume(session.SessionID) {
t.Fatal("expected second consume to fail")
}
// Query should no longer find consumed session
_, ok := store.QueryByPeerIP(peerIP)
if ok {
t.Fatal("expected consumed session to not be found by query")
}
}
func TestPendingStore_Expiry(t *testing.T) {
store := NewPendingStore(50 * time.Millisecond)
peerIP := netip.MustParseAddr("100.64.0.1")
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-3")
if err != nil {
t.Fatalf("Add failed: %v", err)
}
// Should be found immediately
_, ok := store.QueryByPeerIP(peerIP)
if !ok {
t.Fatal("expected to find session before expiry")
}
// Wait for expiry
time.Sleep(100 * time.Millisecond)
// Should not be found after expiry
_, ok = store.QueryByPeerIP(peerIP)
if ok {
t.Fatal("expected session to be expired")
}
// Consume should also fail
if store.Consume(session.SessionID) {
t.Fatal("expected consume of expired session to fail")
}
}
func TestPendingStore_ReplayProtection(t *testing.T) {
store := NewPendingStore(DefaultSessionTTL)
peerIP := netip.MustParseAddr("100.64.0.1")
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-same")
if err != nil {
t.Fatalf("first Add failed: %v", err)
}
// Same nonce should be rejected
_, err = store.Add(peerIP, "admin", ".", "user@example.com", "nonce-same")
if err == nil {
t.Fatal("expected duplicate nonce to be rejected")
}
}
func TestPendingStore_Cleanup(t *testing.T) {
store := NewPendingStore(50 * time.Millisecond)
peerIP := netip.MustParseAddr("100.64.0.1")
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-cleanup")
if err != nil {
t.Fatalf("Add failed: %v", err)
}
if store.Count() != 1 {
t.Fatalf("expected count 1, got %d", store.Count())
}
// Wait for expiry then trigger cleanup
time.Sleep(100 * time.Millisecond)
store.cleanup()
if store.Count() != 0 {
t.Fatalf("expected count 0 after cleanup, got %d", store.Count())
}
}
func TestPendingStore_CleanupBackground(t *testing.T) {
store := NewPendingStore(50 * time.Millisecond)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
store.StartCleanup(ctx)
peerIP := netip.MustParseAddr("100.64.0.1")
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-bg-cleanup")
if err != nil {
t.Fatalf("Add failed: %v", err)
}
// Wait for expiry + cleanup interval
time.Sleep(200 * time.Millisecond)
_, ok := store.QueryByPeerIP(peerIP)
if ok {
t.Fatal("expected session to be cleaned up")
}
}
func TestPendingStore_ConcurrentAccess(t *testing.T) {
store := NewPendingStore(DefaultSessionTTL)
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
ip := netip.AddrFrom4([4]byte{100, 64, byte(i / 256), byte(i % 256)})
nonce := "nonce-" + string(rune(i+'A'))
if i >= 26 {
nonce = "nonce-" + string(rune(i-26+'a'))
}
session, err := store.Add(ip, "admin", ".", "user", nonce)
if err != nil {
return // nonce collision in test is expected
}
store.QueryByPeerIP(ip)
store.Consume(session.SessionID)
}(i)
}
wg.Wait()
}
func TestPendingStore_MultipleSessions(t *testing.T) {
store := NewPendingStore(DefaultSessionTTL)
ip1 := netip.MustParseAddr("100.64.0.1")
ip2 := netip.MustParseAddr("100.64.0.2")
s1, err := store.Add(ip1, "admin", ".", "user1", "nonce-a")
if err != nil {
t.Fatalf("Add s1 failed: %v", err)
}
s2, err := store.Add(ip2, "jdoe", "DOMAIN", "user2", "nonce-b")
if err != nil {
t.Fatalf("Add s2 failed: %v", err)
}
// Query each
found1, ok := store.QueryByPeerIP(ip1)
if !ok || found1.SessionID != s1.SessionID {
t.Fatal("expected to find s1")
}
found2, ok := store.QueryByPeerIP(ip2)
if !ok || found2.SessionID != s2.SessionID {
t.Fatal("expected to find s2")
}
if found2.Domain != "DOMAIN" {
t.Errorf("expected domain DOMAIN, got %s", found2.Domain)
}
if store.Count() != 2 {
t.Errorf("expected count 2, got %d", store.Count())
}
}
func TestGenerateNonce(t *testing.T) {
nonce1, err := GenerateNonce()
if err != nil {
t.Fatalf("GenerateNonce failed: %v", err)
}
nonce2, err := GenerateNonce()
if err != nil {
t.Fatalf("GenerateNonce failed: %v", err)
}
if len(nonce1) != nonceLength*2 { // hex encoding doubles the length
t.Errorf("expected nonce length %d, got %d", nonceLength*2, len(nonce1))
}
if nonce1 == nonce2 {
t.Error("expected unique nonces")
}
}
func TestParseWindowsUsername(t *testing.T) {
tests := []struct {
input string
expectedUser string
expectedDomain string
}{
{"admin", "admin", "."},
{"DOMAIN\\admin", "admin", "DOMAIN"},
{"admin@domain.com", "admin", "domain.com"},
{".\\localuser", "localuser", "."},
}
for _, tt := range tests {
user, domain := parseWindowsUsername(tt.input)
if user != tt.expectedUser {
t.Errorf("parseWindowsUsername(%q) user = %q, want %q", tt.input, user, tt.expectedUser)
}
if domain != tt.expectedDomain {
t.Errorf("parseWindowsUsername(%q) domain = %q, want %q", tt.input, domain, tt.expectedDomain)
}
}
}

View File

@@ -0,0 +1,19 @@
//go:build !windows
package server
import "context"
type stubPipeServer struct{}
func newPipeServer(_ *PendingStore) PipeServer {
return &stubPipeServer{}
}
func (s *stubPipeServer) Start(_ context.Context) error {
return nil
}
func (s *stubPipeServer) Stop() error {
return nil
}

View File

@@ -0,0 +1,164 @@
//go:build windows
package server
import (
"context"
"encoding/json"
"io"
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/Microsoft/go-winio"
)
const (
// PipeName is the named pipe path used for IPC between the NetBird agent and
// the Credential Provider DLL.
PipeName = `\\.\pipe\netbird-rdp-auth`
// pipeSDDL restricts access to LOCAL_SYSTEM (SY) and Administrators (BA).
pipeSDDL = "D:P(A;;GA;;;SY)(A;;GA;;;BA)"
// maxPipeRequestSize is the maximum size of a pipe request in bytes.
maxPipeRequestSize = 4096
)
// windowsPipeServer implements the PipeServer interface for Windows.
type windowsPipeServer struct {
pending *PendingStore
listener net.Listener
mu sync.Mutex
ctx context.Context
cancel context.CancelFunc
}
func newPipeServer(pending *PendingStore) PipeServer {
return &windowsPipeServer{
pending: pending,
}
}
func (s *windowsPipeServer) Start(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
s.ctx, s.cancel = context.WithCancel(ctx)
cfg := &winio.PipeConfig{
SecurityDescriptor: pipeSDDL,
}
listener, err := winio.ListenPipe(PipeName, cfg)
if err != nil {
return err
}
s.listener = listener
go s.acceptLoop()
log.Infof("RDP named pipe server started on %s", PipeName)
return nil
}
func (s *windowsPipeServer) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cancel != nil {
s.cancel()
}
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
return err
}
return nil
}
func (s *windowsPipeServer) acceptLoop() {
for {
conn, err := s.listener.Accept()
if err != nil {
if s.ctx.Err() != nil {
return
}
log.Debugf("RDP pipe accept error: %v", err)
continue
}
go s.handlePipeConnection(conn)
}
}
func (s *windowsPipeServer) handlePipeConnection(conn net.Conn) {
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("RDP pipe close: %v", err)
}
}()
data, err := io.ReadAll(io.LimitReader(conn, maxPipeRequestSize))
if err != nil {
log.Debugf("RDP pipe read: %v", err)
return
}
var req PipeRequest
if err := json.Unmarshal(data, &req); err != nil {
log.Debugf("RDP pipe unmarshal: %v", err)
return
}
var resp PipeResponse
switch req.Action {
case PipeActionQuery:
resp = s.handleQuery(req.RemoteIP)
case PipeActionConsume:
resp = s.handleConsume(req.SessionID)
default:
log.Debugf("RDP pipe unknown action: %s", req.Action)
return
}
respData, err := json.Marshal(resp)
if err != nil {
log.Debugf("RDP pipe marshal response: %v", err)
return
}
if _, err := conn.Write(respData); err != nil {
log.Debugf("RDP pipe write response: %v", err)
}
}
func (s *windowsPipeServer) handleQuery(remoteIP string) PipeResponse {
peerIP, err := parseAddr(remoteIP)
if err != nil {
log.Debugf("RDP pipe invalid remote IP: %s", remoteIP)
return PipeResponse{Found: false}
}
session, found := s.pending.QueryByPeerIP(peerIP)
if !found {
return PipeResponse{Found: false}
}
return PipeResponse{
Found: true,
SessionID: session.SessionID,
OSUser: session.OSUsername,
Domain: session.Domain,
}
}
func (s *windowsPipeServer) handleConsume(sessionID string) PipeResponse {
if s.pending.Consume(sessionID) {
return PipeResponse{Found: true, SessionID: sessionID}
}
return PipeResponse{Found: false}
}

View File

@@ -0,0 +1,48 @@
package server
// AuthRequest is the sideband authorization request sent by the connecting peer
// to the target peer's RDP auth server over the WireGuard tunnel.
type AuthRequest struct {
JWTToken string `json:"jwt_token"`
RequestedUser string `json:"requested_user"`
ClientPeerIP string `json:"client_peer_ip"`
Nonce string `json:"nonce"`
}
// AuthResponse is the sideband authorization response sent by the target peer
// back to the connecting peer.
type AuthResponse struct {
Status string `json:"status"` // "authorized" or "denied"
SessionID string `json:"session_id,omitempty"`
ExpiresAt int64 `json:"expires_at,omitempty"` // unix timestamp
OSUser string `json:"os_user,omitempty"`
Reason string `json:"reason,omitempty"`
}
// PipeRequest is the IPC request from the Credential Provider DLL to the NetBird agent
// via the named pipe.
type PipeRequest struct {
Action string `json:"action"` // "query_pending" or "consume"
RemoteIP string `json:"remote_ip"` // connecting peer's WG IP
SessionID string `json:"session_id,omitempty"` // for consume action
}
// PipeResponse is the IPC response from the NetBird agent to the Credential Provider DLL.
type PipeResponse struct {
Found bool `json:"found"`
SessionID string `json:"session_id,omitempty"`
OSUser string `json:"os_user,omitempty"`
Domain string `json:"domain,omitempty"`
}
const (
// StatusAuthorized indicates the RDP session was authorized.
StatusAuthorized = "authorized"
// StatusDenied indicates the RDP session was denied.
StatusDenied = "denied"
// PipeActionQuery queries for a pending session by remote IP.
PipeActionQuery = "query_pending"
// PipeActionConsume marks a pending session as consumed.
PipeActionConsume = "consume"
)

286
client/rdp/server/server.go Normal file
View File

@@ -0,0 +1,286 @@
package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
// InternalRDPAuthPort is the internal port the sideband auth server listens on.
InternalRDPAuthPort = 22023
// DefaultRDPAuthPort is the external port on the WireGuard interface (DNAT target).
DefaultRDPAuthPort = 3390
// maxRequestSize is the maximum size of an auth request in bytes.
maxRequestSize = 64 * 1024
// connectionTimeout is the timeout for a single auth connection.
connectionTimeout = 30 * time.Second
)
// JWTValidator validates JWT tokens and extracts user identity.
type JWTValidator interface {
ValidateAndExtract(token string) (userID string, err error)
}
// Authorizer checks if a user is authorized for RDP access.
type Authorizer interface {
Authorize(jwtUserID, osUsername string) (string, error)
}
// Server is the sideband RDP authorization server that listens on the WireGuard interface.
type Server struct {
listener net.Listener
pending *PendingStore
pipeServer PipeServer
jwtValidator JWTValidator
authorizer Authorizer
networkAddr netip.Prefix // WireGuard network for source IP validation
mu sync.Mutex
ctx context.Context
cancel context.CancelFunc
}
// PipeServer is the interface for the named pipe IPC server (platform-specific).
type PipeServer interface {
Start(ctx context.Context) error
Stop() error
}
// Config holds the configuration for the RDP auth server.
type Config struct {
JWTValidator JWTValidator
Authorizer Authorizer
NetworkAddr netip.Prefix
SessionTTL time.Duration
}
// New creates a new RDP sideband auth server.
func New(cfg *Config) *Server {
ttl := cfg.SessionTTL
if ttl <= 0 {
ttl = DefaultSessionTTL
}
pending := NewPendingStore(ttl)
return &Server{
pending: pending,
pipeServer: newPipeServer(pending),
jwtValidator: cfg.JWTValidator,
authorizer: cfg.Authorizer,
networkAddr: cfg.NetworkAddr,
}
}
// Start begins listening for sideband auth requests on the given address.
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.listener != nil {
return errors.New("RDP auth server already running")
}
s.ctx, s.cancel = context.WithCancel(ctx)
listenAddr := net.TCPAddrFromAddrPort(addr)
listener, err := net.ListenTCP("tcp", listenAddr)
if err != nil {
return fmt.Errorf("listen on %s: %w", addr, err)
}
s.listener = listener
s.pending.StartCleanup(s.ctx)
if s.pipeServer != nil {
if err := s.pipeServer.Start(s.ctx); err != nil {
log.Warnf("failed to start RDP named pipe server: %v", err)
}
}
go s.acceptLoop()
log.Infof("RDP sideband auth server started on %s", addr)
return nil
}
// Stop shuts down the server and cleans up resources.
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cancel != nil {
s.cancel()
}
if s.pipeServer != nil {
if err := s.pipeServer.Stop(); err != nil {
log.Warnf("failed to stop RDP named pipe server: %v", err)
}
}
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
if err != nil {
return fmt.Errorf("close listener: %w", err)
}
}
log.Info("RDP sideband auth server stopped")
return nil
}
// GetPendingStore returns the pending session store (for testing/named pipe access).
func (s *Server) GetPendingStore() *PendingStore {
return s.pending
}
func (s *Server) acceptLoop() {
for {
conn, err := s.listener.Accept()
if err != nil {
if s.ctx.Err() != nil {
return
}
log.Debugf("RDP auth accept error: %v", err)
continue
}
go s.handleConnection(conn)
}
}
func (s *Server) handleConnection(conn net.Conn) {
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("RDP auth close connection: %v", err)
}
}()
if err := conn.SetDeadline(time.Now().Add(connectionTimeout)); err != nil {
log.Debugf("RDP auth set deadline: %v", err)
return
}
// Validate source IP is from WireGuard network
remoteAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
if err != nil {
log.Debugf("RDP auth parse remote addr: %v", err)
return
}
if !s.networkAddr.Contains(remoteAddr.Addr()) {
log.Warnf("RDP auth rejected connection from non-WG address: %s", remoteAddr.Addr())
return
}
// Read request
data, err := io.ReadAll(io.LimitReader(conn, maxRequestSize))
if err != nil {
log.Debugf("RDP auth read request: %v", err)
return
}
var req AuthRequest
if err := json.Unmarshal(data, &req); err != nil {
log.Debugf("RDP auth unmarshal request: %v", err)
s.sendResponse(conn, &AuthResponse{Status: StatusDenied, Reason: "invalid request format"})
return
}
response := s.processAuthRequest(remoteAddr.Addr(), &req)
s.sendResponse(conn, response)
}
func (s *Server) processAuthRequest(peerIP netip.Addr, req *AuthRequest) *AuthResponse {
// Validate JWT
if s.jwtValidator == nil {
// No JWT validation configured - for POC, accept all requests from WG peers
log.Warnf("RDP auth: no JWT validator configured, accepting request from %s", peerIP)
return s.createSession(peerIP, req, "no-jwt-validation")
}
userID, err := s.jwtValidator.ValidateAndExtract(req.JWTToken)
if err != nil {
log.Warnf("RDP auth JWT validation failed for %s: %v", peerIP, err)
return &AuthResponse{Status: StatusDenied, Reason: "JWT validation failed"}
}
// Check authorization
if s.authorizer != nil {
if _, err := s.authorizer.Authorize(userID, req.RequestedUser); err != nil {
log.Warnf("RDP auth denied for user %s -> %s: %v", userID, req.RequestedUser, err)
return &AuthResponse{Status: StatusDenied, Reason: "not authorized for this user"}
}
}
return s.createSession(peerIP, req, userID)
}
func (s *Server) createSession(peerIP netip.Addr, req *AuthRequest, jwtUserID string) *AuthResponse {
// Parse domain from requested user (DOMAIN\user or user@domain)
osUser, domain := parseWindowsUsername(req.RequestedUser)
session, err := s.pending.Add(peerIP, osUser, domain, jwtUserID, req.Nonce)
if err != nil {
log.Warnf("RDP auth create session failed: %v", err)
return &AuthResponse{Status: StatusDenied, Reason: err.Error()}
}
return &AuthResponse{
Status: StatusAuthorized,
SessionID: session.SessionID,
ExpiresAt: session.ExpiresAt.Unix(),
OSUser: session.OSUsername,
}
}
func (s *Server) sendResponse(conn net.Conn, resp *AuthResponse) {
data, err := json.Marshal(resp)
if err != nil {
log.Debugf("RDP auth marshal response: %v", err)
return
}
if _, err := conn.Write(data); err != nil {
log.Debugf("RDP auth write response: %v", err)
}
}
// parseWindowsUsername extracts username and domain from Windows username formats.
// Supports DOMAIN\username, username@domain, and plain username.
func parseWindowsUsername(fullUsername string) (username, domain string) {
for i := len(fullUsername) - 1; i >= 0; i-- {
if fullUsername[i] == '\\' {
return fullUsername[i+1:], fullUsername[:i]
}
}
if idx := indexOf(fullUsername, '@'); idx != -1 {
return fullUsername[:idx], fullUsername[idx+1:]
}
return fullUsername, "."
}
func indexOf(s string, c byte) int {
for i := 0; i < len(s); i++ {
if s[i] == c {
return i
}
}
return -1
}