mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 16:56:39 +00:00
Implement sideband authorization and credential provider architecture for passwordless RDP access to Windows peers via NetBird. Go components: - Sideband RDP auth server (TCP on WG interface, port 3390/22023) - Pending session store with TTL expiry and replay protection - Named pipe IPC server (\\.\pipe\netbird-rdp-auth) for credential provider - Sideband client for connecting peer to request authorization - CLI command `netbird rdp [user@]host` with JWT auth flow - Engine integration with DNAT port redirection Rust credential provider DLL (client/rdp/credprov/): - COM DLL implementing ICredentialProvider + ICredentialProviderCredential - Loaded by Windows LogonUI.exe at the RDP login screen - Queries NetBird agent via named pipe for pending sessions - Performs S4U logon (LsaLogonUser) for passwordless Windows token creation - Self-registration via regsvr32 (DllRegisterServer/DllUnregisterServer) https://claude.ai/code/session_01C38bCDyYzLgxYLVwJkcUng
310 lines
9.7 KiB
Rust
310 lines
9.7 KiB
Rust
//! NetBird RDP Credential Provider for Windows.
|
|
//!
|
|
//! This DLL is a Windows Credential Provider that enables passwordless RDP access
|
|
//! to machines running the NetBird agent. It is loaded by Windows' LogonUI.exe
|
|
//! via COM when the login screen is displayed.
|
|
//!
|
|
//! ## How it works
|
|
//!
|
|
//! 1. The DLL is registered as a Credential Provider in the Windows registry
|
|
//! 2. When an RDP session begins, LogonUI loads the DLL
|
|
//! 3. The DLL queries the local NetBird agent via named pipe for pending sessions
|
|
//! 4. If a pending session exists for the connecting peer, the DLL:
|
|
//! - Shows a "NetBird Login" credential tile
|
|
//! - Performs S4U logon to create a Windows token without a password
|
|
//! - Returns the token to LogonUI for session creation
|
|
|
|
mod credential;
|
|
mod guid;
|
|
mod named_pipe_client;
|
|
mod provider;
|
|
mod s4u;
|
|
|
|
use guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
|
|
use provider::NetBirdCredentialProvider;
|
|
use std::sync::atomic::{AtomicU32, Ordering};
|
|
use windows::core::*;
|
|
use windows::Win32::Foundation::*;
|
|
use windows::Win32::System::Com::*;
|
|
|
|
/// DLL reference count for COM lifecycle management.
|
|
static DLL_REF_COUNT: AtomicU32 = AtomicU32::new(0);
|
|
|
|
/// DLL module handle.
|
|
static mut DLL_MODULE: HMODULE = HMODULE(std::ptr::null_mut());
|
|
|
|
/// COM class factory for creating NetBirdCredentialProvider instances.
|
|
#[implement(IClassFactory)]
|
|
struct NetBirdClassFactory;
|
|
|
|
impl IClassFactory_Impl for NetBirdClassFactory_Impl {
|
|
fn CreateInstance(
|
|
&self,
|
|
_punkouter: Option<&IUnknown>,
|
|
riid: *const GUID,
|
|
ppvobject: *mut *mut std::ffi::c_void,
|
|
) -> Result<()> {
|
|
unsafe {
|
|
if !ppvobject.is_null() {
|
|
*ppvobject = std::ptr::null_mut();
|
|
}
|
|
}
|
|
|
|
if _punkouter.is_some() {
|
|
return Err(CLASS_E_NOAGGREGATION.into());
|
|
}
|
|
|
|
let provider = NetBirdCredentialProvider::new();
|
|
let unknown: IUnknown = provider.into();
|
|
|
|
unsafe {
|
|
unknown.query(riid, ppvobject).ok()
|
|
}
|
|
}
|
|
|
|
fn LockServer(&self, flock: BOOL) -> Result<()> {
|
|
if flock.as_bool() {
|
|
DLL_REF_COUNT.fetch_add(1, Ordering::SeqCst);
|
|
} else {
|
|
DLL_REF_COUNT.fetch_sub(1, Ordering::SeqCst);
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// DLL entry point.
|
|
#[no_mangle]
|
|
extern "system" fn DllMain(hinstance: HMODULE, reason: u32, _reserved: *mut std::ffi::c_void) -> BOOL {
|
|
const DLL_PROCESS_ATTACH: u32 = 1;
|
|
|
|
if reason == DLL_PROCESS_ATTACH {
|
|
unsafe {
|
|
DLL_MODULE = hinstance;
|
|
}
|
|
}
|
|
|
|
TRUE
|
|
}
|
|
|
|
/// COM entry point: returns a class factory for the requested CLSID.
|
|
#[no_mangle]
|
|
extern "system" fn DllGetClassObject(
|
|
rclsid: *const GUID,
|
|
riid: *const GUID,
|
|
ppv: *mut *mut std::ffi::c_void,
|
|
) -> HRESULT {
|
|
unsafe {
|
|
if ppv.is_null() {
|
|
return E_POINTER;
|
|
}
|
|
*ppv = std::ptr::null_mut();
|
|
|
|
if *rclsid != CLSID_NETBIRD_CREDENTIAL_PROVIDER {
|
|
return CLASS_E_CLASSNOTAVAILABLE;
|
|
}
|
|
|
|
let factory = NetBirdClassFactory;
|
|
let unknown: IUnknown = factory.into();
|
|
|
|
match unknown.query(riid, ppv) {
|
|
Ok(()) => S_OK,
|
|
Err(e) => e.code(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// COM entry point: indicates whether the DLL can be unloaded.
|
|
#[no_mangle]
|
|
extern "system" fn DllCanUnloadNow() -> HRESULT {
|
|
if DLL_REF_COUNT.load(Ordering::SeqCst) == 0 {
|
|
S_OK
|
|
} else {
|
|
S_FALSE
|
|
}
|
|
}
|
|
|
|
/// Self-registration: called by regsvr32 to register the credential provider.
|
|
#[no_mangle]
|
|
extern "system" fn DllRegisterServer() -> HRESULT {
|
|
match register_credential_provider(true) {
|
|
Ok(()) => S_OK,
|
|
Err(_) => E_FAIL,
|
|
}
|
|
}
|
|
|
|
/// Self-unregistration: called by regsvr32 /u to unregister the credential provider.
|
|
#[no_mangle]
|
|
extern "system" fn DllUnregisterServer() -> HRESULT {
|
|
match register_credential_provider(false) {
|
|
Ok(()) => S_OK,
|
|
Err(_) => E_FAIL,
|
|
}
|
|
}
|
|
|
|
fn register_credential_provider(register: bool) -> std::result::Result<(), Box<dyn std::error::Error>> {
|
|
use windows::Win32::System::Registry::*;
|
|
|
|
let clsid_str = format!("{{{:08X}-{:04X}-{:04X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}",
|
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data1,
|
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data2,
|
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data3,
|
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[0],
|
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[1],
|
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[2],
|
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[3],
|
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[4],
|
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[5],
|
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[6],
|
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[7],
|
|
);
|
|
|
|
if register {
|
|
// Register under Credential Providers
|
|
let cp_key_path = format!(
|
|
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
|
|
clsid_str
|
|
);
|
|
|
|
let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
|
let mut hkey = HKEY::default();
|
|
|
|
unsafe {
|
|
let result = RegCreateKeyExW(
|
|
HKEY_LOCAL_MACHINE,
|
|
PCWSTR(cp_key_wide.as_ptr()),
|
|
0,
|
|
PCWSTR::null(),
|
|
REG_OPTION_NON_VOLATILE,
|
|
KEY_WRITE,
|
|
None,
|
|
&mut hkey,
|
|
None,
|
|
);
|
|
if result.is_err() {
|
|
return Err("Failed to create credential provider registry key".into());
|
|
}
|
|
|
|
let value: Vec<u16> = "NetBird RDP Credential Provider"
|
|
.encode_utf16()
|
|
.chain(std::iter::once(0))
|
|
.collect();
|
|
let _ = RegSetValueExW(
|
|
hkey,
|
|
PCWSTR::null(),
|
|
0,
|
|
REG_SZ,
|
|
Some(std::slice::from_raw_parts(
|
|
value.as_ptr() as *const u8,
|
|
value.len() * 2,
|
|
)),
|
|
);
|
|
let _ = RegCloseKey(hkey);
|
|
}
|
|
|
|
// Register CLSID in CLSID hive
|
|
let clsid_key_path = format!(r"CLSID\{}", clsid_str);
|
|
let clsid_key_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
|
|
|
unsafe {
|
|
let result = RegCreateKeyExW(
|
|
HKEY_CLASSES_ROOT,
|
|
PCWSTR(clsid_key_wide.as_ptr()),
|
|
0,
|
|
PCWSTR::null(),
|
|
REG_OPTION_NON_VOLATILE,
|
|
KEY_WRITE,
|
|
None,
|
|
&mut hkey,
|
|
None,
|
|
);
|
|
if result.is_err() {
|
|
return Err("Failed to create CLSID registry key".into());
|
|
}
|
|
let _ = RegCloseKey(hkey);
|
|
|
|
// InprocServer32 subkey
|
|
let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
|
|
let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
|
|
|
|
let result = RegCreateKeyExW(
|
|
HKEY_CLASSES_ROOT,
|
|
PCWSTR(inproc_wide.as_ptr()),
|
|
0,
|
|
PCWSTR::null(),
|
|
REG_OPTION_NON_VOLATILE,
|
|
KEY_WRITE,
|
|
None,
|
|
&mut hkey,
|
|
None,
|
|
);
|
|
if result.is_err() {
|
|
return Err("Failed to create InprocServer32 registry key".into());
|
|
}
|
|
|
|
// Set DLL path
|
|
let mut dll_path = [0u16; 260];
|
|
let len = windows::Win32::System::LibraryLoader::GetModuleFileNameW(
|
|
DLL_MODULE,
|
|
&mut dll_path,
|
|
);
|
|
if len > 0 {
|
|
let _ = RegSetValueExW(
|
|
hkey,
|
|
PCWSTR::null(),
|
|
0,
|
|
REG_SZ,
|
|
Some(std::slice::from_raw_parts(
|
|
dll_path.as_ptr() as *const u8,
|
|
(len as usize + 1) * 2,
|
|
)),
|
|
);
|
|
}
|
|
|
|
// Set threading model
|
|
let threading: Vec<u16> = "Apartment"
|
|
.encode_utf16()
|
|
.chain(std::iter::once(0))
|
|
.collect();
|
|
let threading_name: Vec<u16> = "ThreadingModel"
|
|
.encode_utf16()
|
|
.chain(std::iter::once(0))
|
|
.collect();
|
|
let _ = RegSetValueExW(
|
|
hkey,
|
|
PCWSTR(threading_name.as_ptr()),
|
|
0,
|
|
REG_SZ,
|
|
Some(std::slice::from_raw_parts(
|
|
threading.as_ptr() as *const u8,
|
|
threading.len() * 2,
|
|
)),
|
|
);
|
|
|
|
let _ = RegCloseKey(hkey);
|
|
}
|
|
} else {
|
|
// Unregister
|
|
let cp_key_path = format!(
|
|
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
|
|
clsid_str
|
|
);
|
|
let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
|
|
|
unsafe {
|
|
let _ = RegDeleteKeyW(HKEY_LOCAL_MACHINE, PCWSTR(cp_key_wide.as_ptr()));
|
|
}
|
|
|
|
let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
|
|
let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
|
|
let clsid_key_path = format!(r"CLSID\{}", clsid_str);
|
|
let clsid_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
|
|
|
unsafe {
|
|
let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(inproc_wide.as_ptr()));
|
|
let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(clsid_wide.as_ptr()));
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|