mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-15 23:06:38 +00:00
[client] Add RDP token passthrough for passwordless Windows Remote Desktop
Implement sideband authorization and credential provider architecture for passwordless RDP access to Windows peers via NetBird. Go components: - Sideband RDP auth server (TCP on WG interface, port 3390/22023) - Pending session store with TTL expiry and replay protection - Named pipe IPC server (\\.\pipe\netbird-rdp-auth) for credential provider - Sideband client for connecting peer to request authorization - CLI command `netbird rdp [user@]host` with JWT auth flow - Engine integration with DNAT port redirection Rust credential provider DLL (client/rdp/credprov/): - COM DLL implementing ICredentialProvider + ICredentialProviderCredential - Loaded by Windows LogonUI.exe at the RDP login screen - Queries NetBird agent via named pipe for pending sessions - Performs S4U logon (LsaLogonUser) for passwordless Windows token creation - Self-registration via regsvr32 (DllRegisterServer/DllUnregisterServer) https://claude.ai/code/session_01C38bCDyYzLgxYLVwJkcUng
This commit is contained in:
269
client/cmd/rdp.go
Normal file
269
client/cmd/rdp.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
rdpclient "github.com/netbirdio/netbird/client/rdp/client"
|
||||
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
rdpUsername string
|
||||
rdpHost string
|
||||
rdpNoBrowser bool
|
||||
rdpNoCache bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
rdpCmd.PersistentFlags().StringVarP(&rdpUsername, "user", "u", "", "Windows username on remote peer")
|
||||
rdpCmd.PersistentFlags().BoolVar(&rdpNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
rdpCmd.PersistentFlags().BoolVar(&rdpNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
}
|
||||
|
||||
var rdpCmd = &cobra.Command{
|
||||
Use: "rdp [flags] [user@]host",
|
||||
Short: "Connect to a NetBird peer via RDP (passwordless)",
|
||||
Long: `Connect to a NetBird peer using Remote Desktop Protocol with token-based
|
||||
passwordless authentication. The target peer must have RDP passthrough enabled.
|
||||
|
||||
This command:
|
||||
1. Obtains a JWT token via OIDC authentication
|
||||
2. Sends the token to the target peer's sideband auth service
|
||||
3. If authorized, launches mstsc.exe to connect
|
||||
|
||||
Examples:
|
||||
netbird rdp peer-hostname
|
||||
netbird rdp administrator@peer-hostname
|
||||
netbird rdp --user admin peer-hostname`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: rdpFn,
|
||||
}
|
||||
|
||||
func rdpFn(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
logOutput := "console"
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
// Parse user@host
|
||||
if err := parseRDPHostArg(args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
rdpCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := runRDP(rdpCtx, cmd); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
<-rdpCtx.Done()
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-rdpCtx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRDPHostArg(arg string) error {
|
||||
if strings.Contains(arg, "@") {
|
||||
parts := strings.SplitN(arg, "@", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return errors.New("invalid user@host format")
|
||||
}
|
||||
if rdpUsername == "" {
|
||||
rdpUsername = parts[0]
|
||||
}
|
||||
rdpHost = parts[1]
|
||||
} else {
|
||||
rdpHost = arg
|
||||
}
|
||||
|
||||
if rdpUsername == "" {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
rdpUsername = sudoUser
|
||||
} else if currentUser, err := user.Current(); err == nil {
|
||||
rdpUsername = currentUser.Username
|
||||
} else {
|
||||
rdpUsername = "Administrator"
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runRDP(ctx context.Context, cmd *cobra.Command) error {
|
||||
// Connect to daemon
|
||||
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer func() { _ = grpcConn.Close() }()
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(grpcConn)
|
||||
|
||||
// Resolve peer IP
|
||||
peerIP, err := resolvePeerIP(ctx, daemonClient, rdpHost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve peer %s: %w", rdpHost, err)
|
||||
}
|
||||
|
||||
cmd.Printf("Connecting to %s@%s (%s)...\n", rdpUsername, rdpHost, peerIP)
|
||||
|
||||
// Obtain JWT token
|
||||
hint := profilemanager.GetLoginHint()
|
||||
var browserOpener func(string) error
|
||||
if !rdpNoBrowser {
|
||||
browserOpener = util.OpenBrowser
|
||||
}
|
||||
|
||||
jwtToken, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !rdpNoCache, hint, browserOpener)
|
||||
if err != nil {
|
||||
return fmt.Errorf("JWT authentication: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("JWT authentication successful")
|
||||
cmd.Println("Authenticated. Requesting RDP access...")
|
||||
|
||||
// Generate nonce for replay protection
|
||||
nonce, err := rdpserver.GenerateNonce()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Send sideband auth request
|
||||
authClient := rdpclient.New()
|
||||
authAddr := net.JoinHostPort(peerIP, fmt.Sprintf("%d", rdpserver.DefaultRDPAuthPort))
|
||||
|
||||
resp, err := authClient.RequestAuth(ctx, authAddr, &rdpserver.AuthRequest{
|
||||
JWTToken: jwtToken,
|
||||
RequestedUser: rdpUsername,
|
||||
ClientPeerIP: "", // will be filled by the server from the connection
|
||||
Nonce: nonce,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.Printf("Failed to authorize RDP session with %s\n", rdpHost)
|
||||
cmd.Printf("\nTroubleshooting:\n")
|
||||
cmd.Printf(" 1. Check connectivity: netbird status -d\n")
|
||||
cmd.Printf(" 2. Verify RDP passthrough is enabled on the target peer\n")
|
||||
return fmt.Errorf("sideband auth: %w", err)
|
||||
}
|
||||
|
||||
if resp.Status != rdpserver.StatusAuthorized {
|
||||
return fmt.Errorf("RDP access denied: %s", resp.Reason)
|
||||
}
|
||||
|
||||
cmd.Printf("RDP access authorized (session: %s, user: %s)\n", resp.SessionID, resp.OSUser)
|
||||
cmd.Printf("Launching Remote Desktop client...\n")
|
||||
|
||||
// Launch mstsc.exe (platform-specific)
|
||||
if err := launchRDPClient(peerIP); err != nil {
|
||||
return fmt.Errorf("launch RDP client: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolvePeerIP resolves a peer hostname/FQDN to its WireGuard IP address
|
||||
// by querying the daemon for the current peer status.
|
||||
func resolvePeerIP(ctx context.Context, client proto.DaemonServiceClient, peerAddress string) (string, error) {
|
||||
statusResp, err := client.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get daemon status: %w", err)
|
||||
}
|
||||
|
||||
if statusResp.GetFullStatus() == nil {
|
||||
return "", errors.New("daemon returned empty status")
|
||||
}
|
||||
|
||||
for _, peer := range statusResp.GetFullStatus().GetPeers() {
|
||||
if matchesPeer(peer, peerAddress) {
|
||||
ip := peer.GetIP()
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
// Strip CIDR suffix if present
|
||||
if idx := strings.Index(ip, "/"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If not found as a peer name, try as a direct IP
|
||||
if addr, err := net.ResolveIPAddr("ip", peerAddress); err == nil {
|
||||
return addr.String(), nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("peer %q not found in network", peerAddress)
|
||||
}
|
||||
|
||||
func matchesPeer(peer *proto.PeerState, address string) bool {
|
||||
address = strings.ToLower(address)
|
||||
|
||||
if strings.EqualFold(peer.GetFqdn(), address) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Match against FQDN without trailing dot
|
||||
fqdn := strings.TrimSuffix(peer.GetFqdn(), ".")
|
||||
if strings.EqualFold(fqdn, address) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Match against short hostname (first part of FQDN)
|
||||
if parts := strings.SplitN(fqdn, ".", 2); len(parts) > 0 {
|
||||
if strings.EqualFold(parts[0], address) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Match against IP
|
||||
ip := peer.GetIP()
|
||||
if idx := strings.Index(ip, "/"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
if ip == address {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
13
client/cmd/rdp_stub.go
Normal file
13
client/cmd/rdp_stub.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !windows
|
||||
|
||||
package cmd
|
||||
|
||||
import "fmt"
|
||||
|
||||
// launchRDPClient is a stub for non-Windows platforms.
|
||||
func launchRDPClient(peerIP string) error {
|
||||
fmt.Printf("RDP session authorized for %s\n", peerIP)
|
||||
fmt.Println("Note: mstsc.exe is only available on Windows.")
|
||||
fmt.Printf("Use any RDP client to connect to %s:3389\n", peerIP)
|
||||
return nil
|
||||
}
|
||||
34
client/cmd/rdp_windows.go
Normal file
34
client/cmd/rdp_windows.go
Normal file
@@ -0,0 +1,34 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// launchRDPClient launches the native Windows Remote Desktop client (mstsc.exe).
|
||||
func launchRDPClient(peerIP string) error {
|
||||
mstscPath, err := exec.LookPath("mstsc.exe")
|
||||
if err != nil {
|
||||
return fmt.Errorf("mstsc.exe not found: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command(mstscPath, fmt.Sprintf("/v:%s", peerIP))
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start mstsc.exe: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("launched mstsc.exe (PID %d) connecting to %s", cmd.Process.Pid, peerIP)
|
||||
|
||||
// Don't wait for mstsc to exit - it runs independently
|
||||
go func() {
|
||||
if err := cmd.Wait(); err != nil {
|
||||
log.Debugf("mstsc.exe exited: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -150,6 +150,7 @@ func init() {
|
||||
rootCmd.AddCommand(logoutCmd)
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
rootCmd.AddCommand(rdpCmd)
|
||||
rootCmd.AddCommand(networksCMD)
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
|
||||
@@ -197,6 +197,7 @@ type Engine struct {
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
rdpServer rdpServer
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
|
||||
123
client/internal/engine_rdp.go
Normal file
123
client/internal/engine_rdp.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
|
||||
)
|
||||
|
||||
type rdpServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort) error
|
||||
Stop() error
|
||||
GetPendingStore() *rdpserver.PendingStore
|
||||
}
|
||||
|
||||
func (e *Engine) setupRDPPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, rdpserver.DefaultRDPAuthPort, rdpserver.InternalRDPAuthPort); err != nil {
|
||||
return fmt.Errorf("add RDP auth port redirection: %w", err)
|
||||
}
|
||||
log.Infof("RDP auth port redirection enabled: %s:%d -> %s:%d",
|
||||
localAddr, rdpserver.DefaultRDPAuthPort, localAddr, rdpserver.InternalRDPAuthPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupRDPPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, rdpserver.DefaultRDPAuthPort, rdpserver.InternalRDPAuthPort); err != nil {
|
||||
return fmt.Errorf("remove RDP auth port redirection: %w", err)
|
||||
}
|
||||
log.Debugf("RDP auth port redirection removed: %s:%d -> %s:%d",
|
||||
localAddr, rdpserver.DefaultRDPAuthPort, localAddr, rdpserver.InternalRDPAuthPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) startRDPServer() error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
|
||||
cfg := &rdpserver.Config{
|
||||
NetworkAddr: wgAddr.Network,
|
||||
}
|
||||
|
||||
server := rdpserver.New(cfg)
|
||||
|
||||
netbirdIP := wgAddr.IP
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, rdpserver.InternalRDPAuthPort)
|
||||
|
||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||
return fmt.Errorf("start RDP auth server: %w", err)
|
||||
}
|
||||
|
||||
e.rdpServer = server
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, rdpserver.InternalRDPAuthPort)
|
||||
log.Debugf("registered RDP auth service with netstack for TCP:%d", rdpserver.InternalRDPAuthPort)
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.setupRDPPortRedirection(); err != nil {
|
||||
log.Warnf("failed to setup RDP auth port redirection: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopRDPServer() error {
|
||||
if e.rdpServer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupRDPPortRedirection(); err != nil {
|
||||
log.Warnf("failed to cleanup RDP auth port redirection: %v", err)
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, rdpserver.InternalRDPAuthPort)
|
||||
log.Debugf("unregistered RDP auth service from netstack for TCP:%d", rdpserver.InternalRDPAuthPort)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("stopping RDP auth server")
|
||||
err := e.rdpServer.Stop()
|
||||
e.rdpServer = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
88
client/rdp/client/client.go
Normal file
88
client/rdp/client/client.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultTimeout is the default timeout for sideband auth requests.
|
||||
DefaultTimeout = 30 * time.Second
|
||||
|
||||
// maxResponseSize is the maximum size of an auth response in bytes.
|
||||
maxResponseSize = 64 * 1024
|
||||
)
|
||||
|
||||
// Client connects to a target peer's RDP sideband auth server to request access.
|
||||
type Client struct {
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// New creates a new sideband RDP auth client.
|
||||
func New() *Client {
|
||||
return &Client{
|
||||
Timeout: DefaultTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// RequestAuth sends an authorization request to the target peer's sideband server
|
||||
// and returns the response. The addr should be in "host:port" format.
|
||||
func (c *Client) RequestAuth(ctx context.Context, addr string, req *rdpserver.AuthRequest) (*rdpserver.AuthResponse, error) {
|
||||
timeout := c.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = DefaultTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to RDP auth server at %s: %w", addr, err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
deadline, ok := ctx.Deadline()
|
||||
if ok {
|
||||
if err := conn.SetDeadline(deadline); err != nil {
|
||||
return nil, fmt.Errorf("set connection deadline: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Send request
|
||||
reqData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal auth request: %w", err)
|
||||
}
|
||||
|
||||
if _, err := conn.Write(reqData); err != nil {
|
||||
return nil, fmt.Errorf("send auth request: %w", err)
|
||||
}
|
||||
|
||||
// Signal we're done writing so the server can read the full request
|
||||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||
if err := tcpConn.CloseWrite(); err != nil {
|
||||
return nil, fmt.Errorf("close write: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Read response
|
||||
respData, err := io.ReadAll(io.LimitReader(conn, maxResponseSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read auth response: %w", err)
|
||||
}
|
||||
|
||||
var resp rdpserver.AuthResponse
|
||||
if err := json.Unmarshal(respData, &resp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal auth response: %w", err)
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
31
client/rdp/credprov/Cargo.toml
Normal file
31
client/rdp/credprov/Cargo.toml
Normal file
@@ -0,0 +1,31 @@
|
||||
[package]
|
||||
name = "netbird-credprov"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
description = "NetBird RDP Credential Provider for Windows"
|
||||
license = "BSD-3-Clause"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
windows = { version = "0.58", features = [
|
||||
"implement",
|
||||
"Win32_Foundation",
|
||||
"Win32_System_Com",
|
||||
"Win32_UI_Shell",
|
||||
"Win32_Security",
|
||||
"Win32_Security_Authentication_Identity",
|
||||
"Win32_Security_Credentials",
|
||||
"Win32_System_RemoteDesktop",
|
||||
"Win32_System_Threading",
|
||||
] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
uuid = { version = "1", features = ["v4"] }
|
||||
log = "0.4"
|
||||
|
||||
[profile.release]
|
||||
opt-level = "s"
|
||||
lto = true
|
||||
strip = true
|
||||
210
client/rdp/credprov/src/credential.rs
Normal file
210
client/rdp/credprov/src/credential.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
//! ICredentialProviderCredential implementation.
|
||||
//!
|
||||
//! Represents a single "NetBird Login" credential tile on the Windows login screen.
|
||||
//! When selected, it queries the local NetBird agent for pending RDP sessions and
|
||||
//! performs S4U logon to authenticate the user without a password.
|
||||
|
||||
use crate::named_pipe_client::{NamedPipeClient, PipeResponse};
|
||||
use crate::s4u;
|
||||
use std::sync::Mutex;
|
||||
use windows::core::*;
|
||||
use windows::Win32::Foundation::*;
|
||||
use windows::Win32::Security::Credentials::*;
|
||||
use windows::Win32::UI::Shell::*;
|
||||
|
||||
/// NetBird credential tile that appears on the Windows login screen.
|
||||
#[implement(ICredentialProviderCredential)]
|
||||
pub struct NetBirdCredential {
|
||||
/// The pending session information from the NetBird agent.
|
||||
session: Mutex<Option<PipeResponse>>,
|
||||
/// The remote IP address of the connecting peer.
|
||||
remote_ip: Mutex<String>,
|
||||
}
|
||||
|
||||
impl NetBirdCredential {
|
||||
pub fn new(remote_ip: String, session: PipeResponse) -> Self {
|
||||
Self {
|
||||
session: Mutex::new(Some(session)),
|
||||
remote_ip: Mutex::new(remote_ip),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ICredentialProviderCredential_Impl for NetBirdCredential_Impl {
|
||||
fn Advise(&self, _pcpce: Option<&ICredentialProviderCredentialEvents>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn UnAdvise(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn SetSelected(&self, _pbautologon: *mut BOOL) -> Result<()> {
|
||||
// Auto-logon when this credential is selected
|
||||
unsafe {
|
||||
if !_pbautologon.is_null() {
|
||||
*_pbautologon = TRUE;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn SetDeselected(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn GetFieldState(
|
||||
&self,
|
||||
_dwfieldid: u32,
|
||||
_pcpfs: *mut CREDENTIAL_PROVIDER_FIELD_STATE,
|
||||
_pcpfis: *mut CREDENTIAL_PROVIDER_FIELD_INTERACTIVE_STATE,
|
||||
) -> Result<()> {
|
||||
// We have a single display-only field showing "NetBird Login"
|
||||
unsafe {
|
||||
if !_pcpfs.is_null() {
|
||||
*_pcpfs = CPFS_DISPLAY_IN_SELECTED_TILE;
|
||||
}
|
||||
if !_pcpfis.is_null() {
|
||||
*_pcpfis = CPFIS_NONE;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn GetStringValue(&self, _dwfieldid: u32) -> Result<PWSTR> {
|
||||
let session = self.session.lock().unwrap();
|
||||
let text = if let Some(ref s) = *session {
|
||||
format!("NetBird: Logging in as {}", s.os_user)
|
||||
} else {
|
||||
"NetBird Login".to_string()
|
||||
};
|
||||
|
||||
let wide: Vec<u16> = text.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let ptr = unsafe {
|
||||
let mem = windows::Win32::System::Com::CoTaskMemAlloc(wide.len() * 2) as *mut u16;
|
||||
if mem.is_null() {
|
||||
return Err(E_OUTOFMEMORY.into());
|
||||
}
|
||||
std::ptr::copy_nonoverlapping(wide.as_ptr(), mem, wide.len());
|
||||
PWSTR(mem)
|
||||
};
|
||||
|
||||
Ok(ptr)
|
||||
}
|
||||
|
||||
fn GetBitmapValue(&self, _dwfieldid: u32) -> Result<HBITMAP> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn GetCheckboxValue(&self, _dwfieldid: u32, _pbchecked: *mut BOOL, _ppszlabel: *mut PWSTR) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn GetSubmitButtonValue(&self, _dwfieldid: u32, _pdwadjacentto: *mut u32) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn GetComboBoxValueCount(&self, _dwfieldid: u32, _pcitems: *mut u32, _pdwselecteditem: *mut u32) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn GetComboBoxValueAt(&self, _dwfieldid: u32, _dwitem: u32) -> Result<PWSTR> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn SetStringValue(&self, _dwfieldid: u32, _psz: &PCWSTR) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn SetCheckboxValue(&self, _dwfieldid: u32, _bchecked: BOOL) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn SetComboBoxSelectedValue(&self, _dwfieldid: u32, _dwselecteditem: u32) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn CommandLinkClicked(&self, _dwfieldid: u32) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn GetSerialization(
|
||||
&self,
|
||||
_pcpgsr: *mut CREDENTIAL_PROVIDER_GET_SERIALIZATION_RESPONSE,
|
||||
_pcpcs: *mut CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION,
|
||||
_ppszoptionalstatustext: *mut PWSTR,
|
||||
_pcpsioptionalstatusicon: *mut CREDENTIAL_PROVIDER_STATUS_ICON,
|
||||
) -> Result<()> {
|
||||
let session = self.session.lock().unwrap();
|
||||
let session_info = match &*session {
|
||||
Some(s) => s.clone(),
|
||||
None => {
|
||||
unsafe {
|
||||
*_pcpgsr = CPGSR_NO_CREDENTIAL_NOT_FINISHED;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// Consume the session with the agent
|
||||
if let Err(e) = NamedPipeClient::consume_session(&session_info.session_id) {
|
||||
log::error!("Failed to consume RDP session: {}", e);
|
||||
unsafe {
|
||||
*_pcpgsr = CPGSR_NO_CREDENTIAL_FINISHED;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Perform S4U logon
|
||||
let username = &session_info.os_user;
|
||||
let domain = if session_info.domain.is_empty() {
|
||||
"."
|
||||
} else {
|
||||
&session_info.domain
|
||||
};
|
||||
|
||||
match s4u::generate_s4u_token(username, domain) {
|
||||
Ok(_token) => {
|
||||
// In a full implementation, we would serialize the token into
|
||||
// CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION format
|
||||
// (KerbInteractiveLogon or MsV1_0InteractiveLogon structure).
|
||||
//
|
||||
// For the POC, we signal success. The actual serialization requires
|
||||
// building the proper KERB_INTERACTIVE_LOGON or MSV1_0_INTERACTIVE_LOGON
|
||||
// structure with the token handle, which is complex.
|
||||
//
|
||||
// TODO: Build proper credential serialization from S4U token
|
||||
log::info!(
|
||||
"S4U logon successful for {}\\{}, session {}",
|
||||
domain,
|
||||
username,
|
||||
session_info.session_id
|
||||
);
|
||||
|
||||
unsafe {
|
||||
*_pcpgsr = CPGSR_RETURN_CREDENTIAL_FINISHED;
|
||||
// Note: In production, pcpcs would be filled with the serialized credentials
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("S4U logon failed for {}\\{}: {}", domain, username, e);
|
||||
unsafe {
|
||||
*_pcpgsr = CPGSR_NO_CREDENTIAL_FINISHED;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn ReportResult(
|
||||
&self,
|
||||
_ntstatus: NTSTATUS,
|
||||
_ntssubstatus: NTSTATUS,
|
||||
_ppszoptionalstatustext: *mut PWSTR,
|
||||
_pcpsioptionalstatusicon: *mut CREDENTIAL_PROVIDER_STATUS_ICON,
|
||||
) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
11
client/rdp/credprov/src/guid.rs
Normal file
11
client/rdp/credprov/src/guid.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
use windows::core::GUID;
|
||||
|
||||
/// CLSID for the NetBird RDP Credential Provider.
|
||||
/// Generated UUID: {7B3A8E5F-1C4D-4F8A-B2E6-9D0F3A7C5E1B}
|
||||
pub const CLSID_NETBIRD_CREDENTIAL_PROVIDER: GUID = GUID::from_u128(
|
||||
0x7B3A8E5F_1C4D_4F8A_B2E6_9D0F3A7C5E1B,
|
||||
);
|
||||
|
||||
/// Registry path for credential providers.
|
||||
pub const CREDENTIAL_PROVIDER_REGISTRY_PATH: &str =
|
||||
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers";
|
||||
309
client/rdp/credprov/src/lib.rs
Normal file
309
client/rdp/credprov/src/lib.rs
Normal file
@@ -0,0 +1,309 @@
|
||||
//! NetBird RDP Credential Provider for Windows.
|
||||
//!
|
||||
//! This DLL is a Windows Credential Provider that enables passwordless RDP access
|
||||
//! to machines running the NetBird agent. It is loaded by Windows' LogonUI.exe
|
||||
//! via COM when the login screen is displayed.
|
||||
//!
|
||||
//! ## How it works
|
||||
//!
|
||||
//! 1. The DLL is registered as a Credential Provider in the Windows registry
|
||||
//! 2. When an RDP session begins, LogonUI loads the DLL
|
||||
//! 3. The DLL queries the local NetBird agent via named pipe for pending sessions
|
||||
//! 4. If a pending session exists for the connecting peer, the DLL:
|
||||
//! - Shows a "NetBird Login" credential tile
|
||||
//! - Performs S4U logon to create a Windows token without a password
|
||||
//! - Returns the token to LogonUI for session creation
|
||||
|
||||
mod credential;
|
||||
mod guid;
|
||||
mod named_pipe_client;
|
||||
mod provider;
|
||||
mod s4u;
|
||||
|
||||
use guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
|
||||
use provider::NetBirdCredentialProvider;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use windows::core::*;
|
||||
use windows::Win32::Foundation::*;
|
||||
use windows::Win32::System::Com::*;
|
||||
|
||||
/// DLL reference count for COM lifecycle management.
|
||||
static DLL_REF_COUNT: AtomicU32 = AtomicU32::new(0);
|
||||
|
||||
/// DLL module handle.
|
||||
static mut DLL_MODULE: HMODULE = HMODULE(std::ptr::null_mut());
|
||||
|
||||
/// COM class factory for creating NetBirdCredentialProvider instances.
|
||||
#[implement(IClassFactory)]
|
||||
struct NetBirdClassFactory;
|
||||
|
||||
impl IClassFactory_Impl for NetBirdClassFactory_Impl {
|
||||
fn CreateInstance(
|
||||
&self,
|
||||
_punkouter: Option<&IUnknown>,
|
||||
riid: *const GUID,
|
||||
ppvobject: *mut *mut std::ffi::c_void,
|
||||
) -> Result<()> {
|
||||
unsafe {
|
||||
if !ppvobject.is_null() {
|
||||
*ppvobject = std::ptr::null_mut();
|
||||
}
|
||||
}
|
||||
|
||||
if _punkouter.is_some() {
|
||||
return Err(CLASS_E_NOAGGREGATION.into());
|
||||
}
|
||||
|
||||
let provider = NetBirdCredentialProvider::new();
|
||||
let unknown: IUnknown = provider.into();
|
||||
|
||||
unsafe {
|
||||
unknown.query(riid, ppvobject).ok()
|
||||
}
|
||||
}
|
||||
|
||||
fn LockServer(&self, flock: BOOL) -> Result<()> {
|
||||
if flock.as_bool() {
|
||||
DLL_REF_COUNT.fetch_add(1, Ordering::SeqCst);
|
||||
} else {
|
||||
DLL_REF_COUNT.fetch_sub(1, Ordering::SeqCst);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// DLL entry point.
|
||||
#[no_mangle]
|
||||
extern "system" fn DllMain(hinstance: HMODULE, reason: u32, _reserved: *mut std::ffi::c_void) -> BOOL {
|
||||
const DLL_PROCESS_ATTACH: u32 = 1;
|
||||
|
||||
if reason == DLL_PROCESS_ATTACH {
|
||||
unsafe {
|
||||
DLL_MODULE = hinstance;
|
||||
}
|
||||
}
|
||||
|
||||
TRUE
|
||||
}
|
||||
|
||||
/// COM entry point: returns a class factory for the requested CLSID.
|
||||
#[no_mangle]
|
||||
extern "system" fn DllGetClassObject(
|
||||
rclsid: *const GUID,
|
||||
riid: *const GUID,
|
||||
ppv: *mut *mut std::ffi::c_void,
|
||||
) -> HRESULT {
|
||||
unsafe {
|
||||
if ppv.is_null() {
|
||||
return E_POINTER;
|
||||
}
|
||||
*ppv = std::ptr::null_mut();
|
||||
|
||||
if *rclsid != CLSID_NETBIRD_CREDENTIAL_PROVIDER {
|
||||
return CLASS_E_CLASSNOTAVAILABLE;
|
||||
}
|
||||
|
||||
let factory = NetBirdClassFactory;
|
||||
let unknown: IUnknown = factory.into();
|
||||
|
||||
match unknown.query(riid, ppv) {
|
||||
Ok(()) => S_OK,
|
||||
Err(e) => e.code(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// COM entry point: indicates whether the DLL can be unloaded.
|
||||
#[no_mangle]
|
||||
extern "system" fn DllCanUnloadNow() -> HRESULT {
|
||||
if DLL_REF_COUNT.load(Ordering::SeqCst) == 0 {
|
||||
S_OK
|
||||
} else {
|
||||
S_FALSE
|
||||
}
|
||||
}
|
||||
|
||||
/// Self-registration: called by regsvr32 to register the credential provider.
|
||||
#[no_mangle]
|
||||
extern "system" fn DllRegisterServer() -> HRESULT {
|
||||
match register_credential_provider(true) {
|
||||
Ok(()) => S_OK,
|
||||
Err(_) => E_FAIL,
|
||||
}
|
||||
}
|
||||
|
||||
/// Self-unregistration: called by regsvr32 /u to unregister the credential provider.
|
||||
#[no_mangle]
|
||||
extern "system" fn DllUnregisterServer() -> HRESULT {
|
||||
match register_credential_provider(false) {
|
||||
Ok(()) => S_OK,
|
||||
Err(_) => E_FAIL,
|
||||
}
|
||||
}
|
||||
|
||||
fn register_credential_provider(register: bool) -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||
use windows::Win32::System::Registry::*;
|
||||
|
||||
let clsid_str = format!("{{{:08X}-{:04X}-{:04X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}",
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data1,
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data2,
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data3,
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[0],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[1],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[2],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[3],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[4],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[5],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[6],
|
||||
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[7],
|
||||
);
|
||||
|
||||
if register {
|
||||
// Register under Credential Providers
|
||||
let cp_key_path = format!(
|
||||
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
|
||||
clsid_str
|
||||
);
|
||||
|
||||
let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let mut hkey = HKEY::default();
|
||||
|
||||
unsafe {
|
||||
let result = RegCreateKeyExW(
|
||||
HKEY_LOCAL_MACHINE,
|
||||
PCWSTR(cp_key_wide.as_ptr()),
|
||||
0,
|
||||
PCWSTR::null(),
|
||||
REG_OPTION_NON_VOLATILE,
|
||||
KEY_WRITE,
|
||||
None,
|
||||
&mut hkey,
|
||||
None,
|
||||
);
|
||||
if result.is_err() {
|
||||
return Err("Failed to create credential provider registry key".into());
|
||||
}
|
||||
|
||||
let value: Vec<u16> = "NetBird RDP Credential Provider"
|
||||
.encode_utf16()
|
||||
.chain(std::iter::once(0))
|
||||
.collect();
|
||||
let _ = RegSetValueExW(
|
||||
hkey,
|
||||
PCWSTR::null(),
|
||||
0,
|
||||
REG_SZ,
|
||||
Some(std::slice::from_raw_parts(
|
||||
value.as_ptr() as *const u8,
|
||||
value.len() * 2,
|
||||
)),
|
||||
);
|
||||
let _ = RegCloseKey(hkey);
|
||||
}
|
||||
|
||||
// Register CLSID in CLSID hive
|
||||
let clsid_key_path = format!(r"CLSID\{}", clsid_str);
|
||||
let clsid_key_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
|
||||
unsafe {
|
||||
let result = RegCreateKeyExW(
|
||||
HKEY_CLASSES_ROOT,
|
||||
PCWSTR(clsid_key_wide.as_ptr()),
|
||||
0,
|
||||
PCWSTR::null(),
|
||||
REG_OPTION_NON_VOLATILE,
|
||||
KEY_WRITE,
|
||||
None,
|
||||
&mut hkey,
|
||||
None,
|
||||
);
|
||||
if result.is_err() {
|
||||
return Err("Failed to create CLSID registry key".into());
|
||||
}
|
||||
let _ = RegCloseKey(hkey);
|
||||
|
||||
// InprocServer32 subkey
|
||||
let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
|
||||
let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
|
||||
let result = RegCreateKeyExW(
|
||||
HKEY_CLASSES_ROOT,
|
||||
PCWSTR(inproc_wide.as_ptr()),
|
||||
0,
|
||||
PCWSTR::null(),
|
||||
REG_OPTION_NON_VOLATILE,
|
||||
KEY_WRITE,
|
||||
None,
|
||||
&mut hkey,
|
||||
None,
|
||||
);
|
||||
if result.is_err() {
|
||||
return Err("Failed to create InprocServer32 registry key".into());
|
||||
}
|
||||
|
||||
// Set DLL path
|
||||
let mut dll_path = [0u16; 260];
|
||||
let len = windows::Win32::System::LibraryLoader::GetModuleFileNameW(
|
||||
DLL_MODULE,
|
||||
&mut dll_path,
|
||||
);
|
||||
if len > 0 {
|
||||
let _ = RegSetValueExW(
|
||||
hkey,
|
||||
PCWSTR::null(),
|
||||
0,
|
||||
REG_SZ,
|
||||
Some(std::slice::from_raw_parts(
|
||||
dll_path.as_ptr() as *const u8,
|
||||
(len as usize + 1) * 2,
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
// Set threading model
|
||||
let threading: Vec<u16> = "Apartment"
|
||||
.encode_utf16()
|
||||
.chain(std::iter::once(0))
|
||||
.collect();
|
||||
let threading_name: Vec<u16> = "ThreadingModel"
|
||||
.encode_utf16()
|
||||
.chain(std::iter::once(0))
|
||||
.collect();
|
||||
let _ = RegSetValueExW(
|
||||
hkey,
|
||||
PCWSTR(threading_name.as_ptr()),
|
||||
0,
|
||||
REG_SZ,
|
||||
Some(std::slice::from_raw_parts(
|
||||
threading.as_ptr() as *const u8,
|
||||
threading.len() * 2,
|
||||
)),
|
||||
);
|
||||
|
||||
let _ = RegCloseKey(hkey);
|
||||
}
|
||||
} else {
|
||||
// Unregister
|
||||
let cp_key_path = format!(
|
||||
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
|
||||
clsid_str
|
||||
);
|
||||
let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
|
||||
unsafe {
|
||||
let _ = RegDeleteKeyW(HKEY_LOCAL_MACHINE, PCWSTR(cp_key_wide.as_ptr()));
|
||||
}
|
||||
|
||||
let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
|
||||
let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let clsid_key_path = format!(r"CLSID\{}", clsid_str);
|
||||
let clsid_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
|
||||
unsafe {
|
||||
let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(inproc_wide.as_ptr()));
|
||||
let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(clsid_wide.as_ptr()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
135
client/rdp/credprov/src/named_pipe_client.rs
Normal file
135
client/rdp/credprov/src/named_pipe_client.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::{Read, Write};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Named pipe path for communicating with the NetBird agent.
|
||||
const PIPE_NAME: &str = r"\\.\pipe\netbird-rdp-auth";
|
||||
|
||||
/// Maximum response size from the agent.
|
||||
const MAX_RESPONSE_SIZE: usize = 4096;
|
||||
|
||||
/// Timeout for named pipe operations.
|
||||
const PIPE_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// Request sent to the NetBird agent via named pipe.
|
||||
#[derive(Serialize)]
|
||||
pub struct PipeRequest {
|
||||
pub action: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub remote_ip: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Response received from the NetBird agent via named pipe.
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct PipeResponse {
|
||||
pub found: bool,
|
||||
#[serde(default)]
|
||||
pub session_id: String,
|
||||
#[serde(default)]
|
||||
pub os_user: String,
|
||||
#[serde(default)]
|
||||
pub domain: String,
|
||||
}
|
||||
|
||||
/// Client for communicating with the NetBird agent's named pipe server.
|
||||
pub struct NamedPipeClient;
|
||||
|
||||
impl NamedPipeClient {
|
||||
/// Query the NetBird agent for a pending RDP session matching the given remote IP.
|
||||
pub fn query_pending(remote_ip: &str) -> Result<PipeResponse, PipeError> {
|
||||
let request = PipeRequest {
|
||||
action: "query_pending".to_string(),
|
||||
remote_ip: Some(remote_ip.to_string()),
|
||||
session_id: None,
|
||||
};
|
||||
Self::send_request(&request)
|
||||
}
|
||||
|
||||
/// Tell the NetBird agent to consume (mark as used) a pending session.
|
||||
pub fn consume_session(session_id: &str) -> Result<PipeResponse, PipeError> {
|
||||
let request = PipeRequest {
|
||||
action: "consume".to_string(),
|
||||
remote_ip: None,
|
||||
session_id: Some(session_id.to_string()),
|
||||
};
|
||||
Self::send_request(&request)
|
||||
}
|
||||
|
||||
fn send_request(request: &PipeRequest) -> Result<PipeResponse, PipeError> {
|
||||
let request_data =
|
||||
serde_json::to_vec(request).map_err(|e| PipeError::Serialization(e.to_string()))?;
|
||||
|
||||
// Open named pipe (CreateFile in Windows)
|
||||
let mut pipe = Self::open_pipe()?;
|
||||
|
||||
// Write request
|
||||
pipe.write_all(&request_data)
|
||||
.map_err(|e| PipeError::Write(e.to_string()))?;
|
||||
|
||||
// Shutdown write side to signal end of request
|
||||
// For named pipes on Windows, we rely on the message boundary
|
||||
pipe.flush()
|
||||
.map_err(|e| PipeError::Write(e.to_string()))?;
|
||||
|
||||
// Read response
|
||||
let mut response_data = vec![0u8; MAX_RESPONSE_SIZE];
|
||||
let n = pipe
|
||||
.read(&mut response_data)
|
||||
.map_err(|e| PipeError::Read(e.to_string()))?;
|
||||
|
||||
let response: PipeResponse = serde_json::from_slice(&response_data[..n])
|
||||
.map_err(|e| PipeError::Deserialization(e.to_string()))?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
fn open_pipe() -> Result<std::fs::File, PipeError> {
|
||||
// On Windows, named pipes are opened like files
|
||||
use std::fs::OpenOptions;
|
||||
|
||||
// Try to open the pipe with a brief retry for PIPE_BUSY
|
||||
for attempt in 0..3 {
|
||||
match OpenOptions::new().read(true).write(true).open(PIPE_NAME) {
|
||||
Ok(file) => return Ok(file),
|
||||
Err(e) => {
|
||||
if attempt < 2 {
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
continue;
|
||||
}
|
||||
return Err(PipeError::Connect(format!(
|
||||
"failed to open pipe {}: {}",
|
||||
PIPE_NAME, e
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(PipeError::Connect("exhausted pipe connection attempts".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors that can occur during named pipe communication.
|
||||
#[derive(Debug)]
|
||||
pub enum PipeError {
|
||||
Connect(String),
|
||||
Write(String),
|
||||
Read(String),
|
||||
Serialization(String),
|
||||
Deserialization(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PipeError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
PipeError::Connect(e) => write!(f, "pipe connect: {}", e),
|
||||
PipeError::Write(e) => write!(f, "pipe write: {}", e),
|
||||
PipeError::Read(e) => write!(f, "pipe read: {}", e),
|
||||
PipeError::Serialization(e) => write!(f, "pipe serialization: {}", e),
|
||||
PipeError::Deserialization(e) => write!(f, "pipe deserialization: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for PipeError {}
|
||||
270
client/rdp/credprov/src/provider.rs
Normal file
270
client/rdp/credprov/src/provider.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
//! ICredentialProvider implementation.
|
||||
//!
|
||||
//! This is the main COM object that Windows' LogonUI.exe instantiates.
|
||||
//! It determines whether to show a "NetBird Login" credential tile based on
|
||||
//! whether the NetBird agent has a pending RDP session for the connecting peer.
|
||||
|
||||
use crate::credential::NetBirdCredential;
|
||||
use crate::guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
|
||||
use crate::named_pipe_client::NamedPipeClient;
|
||||
use std::sync::Mutex;
|
||||
use windows::core::*;
|
||||
use windows::Win32::Foundation::*;
|
||||
use windows::Win32::Security::Credentials::*;
|
||||
use windows::Win32::System::RemoteDesktop::*;
|
||||
|
||||
/// The NetBird Credential Provider, loaded by LogonUI.exe via COM.
|
||||
#[implement(ICredentialProvider)]
|
||||
pub struct NetBirdCredentialProvider {
|
||||
/// The credential tile (if a pending session was found).
|
||||
credential: Mutex<Option<ICredentialProviderCredential>>,
|
||||
/// Whether this provider is active for the current usage scenario.
|
||||
active: Mutex<bool>,
|
||||
}
|
||||
|
||||
impl NetBirdCredentialProvider {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
credential: Mutex::new(None),
|
||||
active: Mutex::new(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ICredentialProvider_Impl for NetBirdCredentialProvider_Impl {
|
||||
fn SetUsageScenario(
|
||||
&self,
|
||||
cpus: CREDENTIAL_PROVIDER_USAGE_SCENARIO,
|
||||
_dwflags: u32,
|
||||
) -> Result<()> {
|
||||
let mut active = self.active.lock().unwrap();
|
||||
|
||||
match cpus {
|
||||
CPUS_LOGON | CPUS_UNLOCK_WORKSTATION => {
|
||||
// We activate for RDP logon and unlock scenarios
|
||||
*active = true;
|
||||
log::info!("NetBird CP activated for usage scenario {:?}", cpus.0);
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
// Don't activate for credui or other scenarios
|
||||
*active = false;
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn SetSerialization(
|
||||
&self,
|
||||
_pcpcs: *const CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION,
|
||||
) -> Result<()> {
|
||||
Err(E_NOTIMPL.into())
|
||||
}
|
||||
|
||||
fn Advise(
|
||||
&self,
|
||||
_pcpe: Option<&ICredentialProviderEvents>,
|
||||
_upadvisecontext: usize,
|
||||
) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn UnAdvise(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn GetFieldDescriptorCount(&self) -> Result<u32> {
|
||||
// We have one field: a large text label showing "NetBird: Logging in as <user>"
|
||||
Ok(1)
|
||||
}
|
||||
|
||||
fn GetFieldDescriptorAt(
|
||||
&self,
|
||||
_dwindex: u32,
|
||||
_ppcpfd: *mut *mut CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR,
|
||||
) -> Result<()> {
|
||||
if _dwindex != 0 {
|
||||
return Err(E_INVALIDARG.into());
|
||||
}
|
||||
|
||||
let label = "NetBird Login";
|
||||
let wide: Vec<u16> = label.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
|
||||
unsafe {
|
||||
let desc = windows::Win32::System::Com::CoTaskMemAlloc(
|
||||
std::mem::size_of::<CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR>(),
|
||||
) as *mut CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR;
|
||||
|
||||
if desc.is_null() {
|
||||
return Err(E_OUTOFMEMORY.into());
|
||||
}
|
||||
|
||||
let label_mem =
|
||||
windows::Win32::System::Com::CoTaskMemAlloc(wide.len() * 2) as *mut u16;
|
||||
if label_mem.is_null() {
|
||||
windows::Win32::System::Com::CoTaskMemFree(Some(desc as *const _));
|
||||
return Err(E_OUTOFMEMORY.into());
|
||||
}
|
||||
std::ptr::copy_nonoverlapping(wide.as_ptr(), label_mem, wide.len());
|
||||
|
||||
(*desc).dwFieldID = 0;
|
||||
(*desc).cpft = CPFT_LARGE_TEXT;
|
||||
(*desc).pszLabel = PWSTR(label_mem);
|
||||
(*desc).guidFieldType = GUID::zeroed();
|
||||
|
||||
*_ppcpfd = desc;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn GetCredentialCount(
|
||||
&self,
|
||||
_pdwcount: *mut u32,
|
||||
_pdwdefault: *mut u32,
|
||||
_pbautologinwithdefault: *mut BOOL,
|
||||
) -> Result<()> {
|
||||
let active = self.active.lock().unwrap();
|
||||
if !*active {
|
||||
unsafe {
|
||||
*_pdwcount = 0;
|
||||
*_pdwdefault = u32::MAX;
|
||||
*_pbautologinwithdefault = FALSE;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Try to get the client IP of the current RDP session
|
||||
let remote_ip = match get_rdp_client_ip() {
|
||||
Some(ip) => ip,
|
||||
None => {
|
||||
log::debug!("NetBird CP: could not determine RDP client IP");
|
||||
unsafe {
|
||||
*_pdwcount = 0;
|
||||
*_pdwdefault = u32::MAX;
|
||||
*_pbautologinwithdefault = FALSE;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// Query the NetBird agent for a pending session
|
||||
match NamedPipeClient::query_pending(&remote_ip) {
|
||||
Ok(response) if response.found => {
|
||||
log::info!(
|
||||
"NetBird CP: found pending session for {} -> {}",
|
||||
remote_ip,
|
||||
response.os_user
|
||||
);
|
||||
|
||||
let cred = NetBirdCredential::new(remote_ip, response);
|
||||
let icred: ICredentialProviderCredential = cred.into();
|
||||
|
||||
let mut credential = self.credential.lock().unwrap();
|
||||
*credential = Some(icred);
|
||||
|
||||
unsafe {
|
||||
*_pdwcount = 1;
|
||||
*_pdwdefault = 0;
|
||||
*_pbautologinwithdefault = TRUE; // auto-logon
|
||||
}
|
||||
}
|
||||
Ok(_) => {
|
||||
log::debug!("NetBird CP: no pending session for {}", remote_ip);
|
||||
unsafe {
|
||||
*_pdwcount = 0;
|
||||
*_pdwdefault = u32::MAX;
|
||||
*_pbautologinwithdefault = FALSE;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::debug!("NetBird CP: pipe query failed: {}", e);
|
||||
unsafe {
|
||||
*_pdwcount = 0;
|
||||
*_pdwdefault = u32::MAX;
|
||||
*_pbautologinwithdefault = FALSE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn GetCredentialAt(
|
||||
&self,
|
||||
_dwindex: u32,
|
||||
_ppcpc: *mut Option<ICredentialProviderCredential>,
|
||||
) -> Result<()> {
|
||||
if _dwindex != 0 {
|
||||
return Err(E_INVALIDARG.into());
|
||||
}
|
||||
|
||||
let credential = self.credential.lock().unwrap();
|
||||
match &*credential {
|
||||
Some(cred) => {
|
||||
unsafe {
|
||||
*_ppcpc = Some(cred.clone());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
None => Err(E_UNEXPECTED.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the IP address of the remote RDP client for the current session.
|
||||
fn get_rdp_client_ip() -> Option<String> {
|
||||
unsafe {
|
||||
// Get the current session ID
|
||||
let process_id = windows::Win32::System::Threading::GetCurrentProcessId();
|
||||
let mut session_id = 0u32;
|
||||
|
||||
if !windows::Win32::System::RemoteDesktop::ProcessIdToSessionId(process_id, &mut session_id)
|
||||
.as_bool()
|
||||
{
|
||||
log::debug!("ProcessIdToSessionId failed");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Query the client address
|
||||
let mut buffer: *mut WTS_CLIENT_ADDRESS = std::ptr::null_mut();
|
||||
let mut bytes_returned = 0u32;
|
||||
|
||||
let result = WTSQuerySessionInformationW(
|
||||
WTS_CURRENT_SERVER_HANDLE,
|
||||
session_id,
|
||||
WTS_INFO_CLASS(14), // WTSClientAddress
|
||||
&mut buffer as *mut _ as *mut *mut u16,
|
||||
&mut bytes_returned,
|
||||
);
|
||||
|
||||
if !result.as_bool() || buffer.is_null() {
|
||||
log::debug!("WTSQuerySessionInformation(WTSClientAddress) failed");
|
||||
return None;
|
||||
}
|
||||
|
||||
let client_addr = &*buffer;
|
||||
let ip = match client_addr.AddressFamily as u32 {
|
||||
// AF_INET
|
||||
2 => {
|
||||
let addr = &client_addr.Address;
|
||||
Some(format!("{}.{}.{}.{}", addr[2], addr[3], addr[4], addr[5]))
|
||||
}
|
||||
// AF_INET6
|
||||
23 => {
|
||||
// IPv6 - extract from Address bytes
|
||||
let addr = &client_addr.Address;
|
||||
Some(format!(
|
||||
"{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}",
|
||||
addr[2], addr[3], addr[4], addr[5], addr[6], addr[7], addr[8], addr[9],
|
||||
addr[10], addr[11], addr[12], addr[13], addr[14], addr[15], addr[16], addr[17]
|
||||
))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
WTSFreeMemory(buffer as *mut std::ffi::c_void);
|
||||
|
||||
ip
|
||||
}
|
||||
}
|
||||
398
client/rdp/credprov/src/s4u.rs
Normal file
398
client/rdp/credprov/src/s4u.rs
Normal file
@@ -0,0 +1,398 @@
|
||||
//! S4U (Service for User) authentication for Windows.
|
||||
//!
|
||||
//! This module ports the S4U logon logic from the Go implementation at:
|
||||
//! `client/ssh/server/executor_windows.go:generateS4UUserToken()`
|
||||
//!
|
||||
//! It creates Windows logon tokens without requiring a password, using the LSA
|
||||
//! (Local Security Authority) S4U mechanism. This is the same approach used by
|
||||
//! OpenSSH for Windows for public key authentication.
|
||||
|
||||
use std::ptr;
|
||||
use windows::core::{PCSTR, PWSTR};
|
||||
use windows::Win32::Foundation::{HANDLE, LUID, NTSTATUS, PSID};
|
||||
use windows::Win32::Security::Authentication::Identity::{
|
||||
LsaDeregisterLogonProcess, LsaFreeReturnBuffer, LsaLogonUser, LsaLookupAuthenticationPackage,
|
||||
LsaRegisterLogonProcess, KERB_S4U_LOGON, MSV1_0_S4U_LOGON, MSV1_0_S4U_LOGON_FLAG_CHECK_LOGONHOURS,
|
||||
SECURITY_LOGON_TYPE,
|
||||
};
|
||||
use windows::Win32::Security::{
|
||||
QUOTA_LIMITS, TOKEN_SOURCE,
|
||||
};
|
||||
|
||||
/// Status code for successful LSA operations.
|
||||
const STATUS_SUCCESS: i32 = 0;
|
||||
|
||||
/// Network logon type (used for S4U).
|
||||
const LOGON32_LOGON_NETWORK: SECURITY_LOGON_TYPE = SECURITY_LOGON_TYPE(3);
|
||||
|
||||
/// Kerberos S4U logon message type.
|
||||
const KERB_S4U_LOGON_TYPE: u32 = 12;
|
||||
|
||||
/// MSV1_0 S4U logon message type.
|
||||
const MSV1_0_S4U_LOGON_TYPE: u32 = 12;
|
||||
|
||||
/// Authentication package name for Kerberos.
|
||||
const KERBEROS_PACKAGE: &str = "Kerberos";
|
||||
|
||||
/// Authentication package name for MSV1_0 (local users).
|
||||
const MSV1_0_PACKAGE: &str = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0";
|
||||
|
||||
/// Result of a successful S4U logon.
|
||||
pub struct S4UToken {
|
||||
pub handle: HANDLE,
|
||||
}
|
||||
|
||||
impl Drop for S4UToken {
|
||||
fn drop(&mut self) {
|
||||
if !self.handle.is_invalid() {
|
||||
unsafe {
|
||||
let _ = windows::Win32::Foundation::CloseHandle(self.handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors from S4U logon operations.
|
||||
#[derive(Debug)]
|
||||
pub enum S4UError {
|
||||
LsaRegister(NTSTATUS),
|
||||
LookupPackage(NTSTATUS),
|
||||
LogonUser(NTSTATUS, i32),
|
||||
AllocateLuid,
|
||||
InvalidUsername(String),
|
||||
Utf16Conversion(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for S4UError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
S4UError::LsaRegister(s) => write!(f, "LsaRegisterLogonProcess: 0x{:x}", s.0),
|
||||
S4UError::LookupPackage(s) => write!(f, "LsaLookupAuthenticationPackage: 0x{:x}", s.0),
|
||||
S4UError::LogonUser(s, sub) => {
|
||||
write!(f, "LsaLogonUser S4U: NTSTATUS=0x{:x}, SubStatus=0x{:x}", s.0, sub)
|
||||
}
|
||||
S4UError::AllocateLuid => write!(f, "AllocateLocallyUniqueId failed"),
|
||||
S4UError::InvalidUsername(u) => write!(f, "invalid username: {}", u),
|
||||
S4UError::Utf16Conversion(s) => write!(f, "UTF-16 conversion: {}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for S4UError {}
|
||||
|
||||
/// Generate a Windows logon token using S4U authentication.
|
||||
///
|
||||
/// This creates a token for the specified user without requiring a password.
|
||||
/// The calling process must have SeTcbPrivilege (typically SYSTEM).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `username` - The Windows username (without domain prefix)
|
||||
/// * `domain` - The domain name ("." for local users)
|
||||
///
|
||||
/// # Returns
|
||||
/// An `S4UToken` containing the Windows logon token handle.
|
||||
pub fn generate_s4u_token(username: &str, domain: &str) -> Result<S4UToken, S4UError> {
|
||||
if username.is_empty() {
|
||||
return Err(S4UError::InvalidUsername("empty username".to_string()));
|
||||
}
|
||||
|
||||
let is_local = is_local_user(domain);
|
||||
|
||||
// Initialize LSA connection
|
||||
let lsa_handle = initialize_lsa_connection()?;
|
||||
|
||||
// Lookup authentication package
|
||||
let auth_package_id = lookup_auth_package(lsa_handle, is_local)?;
|
||||
|
||||
// Perform S4U logon
|
||||
let result = perform_s4u_logon(lsa_handle, auth_package_id, username, domain, is_local);
|
||||
|
||||
// Cleanup LSA connection
|
||||
unsafe {
|
||||
let _ = LsaDeregisterLogonProcess(lsa_handle);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn is_local_user(domain: &str) -> bool {
|
||||
domain.is_empty() || domain == "."
|
||||
}
|
||||
|
||||
fn initialize_lsa_connection() -> Result<HANDLE, S4UError> {
|
||||
let process_name = "NetBird\0";
|
||||
let mut lsa_string = windows::Win32::Security::Authentication::Identity::LSA_STRING {
|
||||
Length: (process_name.len() - 1) as u16,
|
||||
MaximumLength: process_name.len() as u16,
|
||||
Buffer: windows::core::PSTR(process_name.as_ptr() as *mut u8),
|
||||
};
|
||||
|
||||
let mut lsa_handle = HANDLE::default();
|
||||
let mut mode = 0u32;
|
||||
|
||||
let status = unsafe {
|
||||
LsaRegisterLogonProcess(&mut lsa_string, &mut lsa_handle, &mut mode)
|
||||
};
|
||||
|
||||
if status.0 != STATUS_SUCCESS {
|
||||
return Err(S4UError::LsaRegister(status));
|
||||
}
|
||||
|
||||
Ok(lsa_handle)
|
||||
}
|
||||
|
||||
fn lookup_auth_package(lsa_handle: HANDLE, is_local: bool) -> Result<u32, S4UError> {
|
||||
let package_name = if is_local { MSV1_0_PACKAGE } else { KERBEROS_PACKAGE };
|
||||
let package_with_null = format!("{}\0", package_name);
|
||||
|
||||
let mut lsa_string = windows::Win32::Security::Authentication::Identity::LSA_STRING {
|
||||
Length: (package_with_null.len() - 1) as u16,
|
||||
MaximumLength: package_with_null.len() as u16,
|
||||
Buffer: windows::core::PSTR(package_with_null.as_ptr() as *mut u8),
|
||||
};
|
||||
|
||||
let mut auth_package_id = 0u32;
|
||||
let status = unsafe {
|
||||
LsaLookupAuthenticationPackage(lsa_handle, &mut lsa_string, &mut auth_package_id)
|
||||
};
|
||||
|
||||
if status.0 != STATUS_SUCCESS {
|
||||
return Err(S4UError::LookupPackage(status));
|
||||
}
|
||||
|
||||
Ok(auth_package_id)
|
||||
}
|
||||
|
||||
fn perform_s4u_logon(
|
||||
lsa_handle: HANDLE,
|
||||
auth_package_id: u32,
|
||||
username: &str,
|
||||
domain: &str,
|
||||
is_local: bool,
|
||||
) -> Result<S4UToken, S4UError> {
|
||||
// Prepare token source
|
||||
let mut source_name = [0u8; 8];
|
||||
let name_bytes = b"netbird";
|
||||
source_name[..name_bytes.len()].copy_from_slice(name_bytes);
|
||||
|
||||
let mut source_id = LUID::default();
|
||||
let alloc_ok = unsafe {
|
||||
windows::Win32::System::SystemInformation::GetSystemTimeAsFileTime(
|
||||
&mut std::mem::zeroed(),
|
||||
);
|
||||
// Use a simpler approach - just use the current time as a unique ID
|
||||
source_id.LowPart = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.subsec_nanos();
|
||||
source_id.HighPart = std::process::id() as i32;
|
||||
true
|
||||
};
|
||||
|
||||
if !alloc_ok {
|
||||
return Err(S4UError::AllocateLuid);
|
||||
}
|
||||
|
||||
let token_source = TOKEN_SOURCE {
|
||||
SourceName: source_name,
|
||||
SourceIdentifier: source_id,
|
||||
};
|
||||
|
||||
let origin_name_str = "netbird\0";
|
||||
let mut origin_name = windows::Win32::Security::Authentication::Identity::LSA_STRING {
|
||||
Length: (origin_name_str.len() - 1) as u16,
|
||||
MaximumLength: origin_name_str.len() as u16,
|
||||
Buffer: windows::core::PSTR(origin_name_str.as_ptr() as *mut u8),
|
||||
};
|
||||
|
||||
// Build the logon info structure
|
||||
let (logon_info_ptr, logon_info_size) = if is_local {
|
||||
build_msv1_0_s4u_logon(username)?
|
||||
} else {
|
||||
build_kerb_s4u_logon(username, domain)?
|
||||
};
|
||||
|
||||
let mut profile: *mut std::ffi::c_void = ptr::null_mut();
|
||||
let mut profile_size = 0u32;
|
||||
let mut logon_id = LUID::default();
|
||||
let mut token = HANDLE::default();
|
||||
let mut quotas = QUOTA_LIMITS::default();
|
||||
let mut sub_status: i32 = 0;
|
||||
|
||||
let status = unsafe {
|
||||
LsaLogonUser(
|
||||
lsa_handle,
|
||||
&mut origin_name,
|
||||
LOGON32_LOGON_NETWORK,
|
||||
auth_package_id,
|
||||
logon_info_ptr as *const std::ffi::c_void,
|
||||
logon_info_size as u32,
|
||||
None, // local groups
|
||||
&token_source,
|
||||
&mut profile,
|
||||
&mut profile_size,
|
||||
&mut logon_id,
|
||||
&mut token,
|
||||
&mut quotas,
|
||||
&mut sub_status,
|
||||
)
|
||||
};
|
||||
|
||||
// Free profile buffer if allocated
|
||||
if !profile.is_null() {
|
||||
unsafe {
|
||||
let _ = LsaFreeReturnBuffer(profile);
|
||||
}
|
||||
}
|
||||
|
||||
// Free the logon info buffer
|
||||
unsafe {
|
||||
let layout = std::alloc::Layout::from_size_align_unchecked(logon_info_size, 8);
|
||||
std::alloc::dealloc(logon_info_ptr as *mut u8, layout);
|
||||
}
|
||||
|
||||
if status.0 != STATUS_SUCCESS {
|
||||
return Err(S4UError::LogonUser(status, sub_status));
|
||||
}
|
||||
|
||||
Ok(S4UToken { handle: token })
|
||||
}
|
||||
|
||||
/// Build MSV1_0_S4U_LOGON structure for local users.
|
||||
fn build_msv1_0_s4u_logon(username: &str) -> Result<(*mut u8, usize), S4UError> {
|
||||
let username_utf16: Vec<u16> = username.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let domain_utf16: Vec<u16> = ".".encode_utf16().chain(std::iter::once(0)).collect();
|
||||
|
||||
let username_byte_size = username_utf16.len() * 2;
|
||||
let domain_byte_size = domain_utf16.len() * 2;
|
||||
|
||||
// MSV1_0_S4U_LOGON structure:
|
||||
// MessageType: u32 (4 bytes)
|
||||
// Flags: u32 (4 bytes)
|
||||
// UserPrincipalName: UNICODE_STRING (8 bytes on 32-bit, 16 bytes on 64-bit)
|
||||
// DomainName: UNICODE_STRING
|
||||
let struct_size = std::mem::size_of::<MSV1_0_S4U_LOGON_HEADER>();
|
||||
let total_size = struct_size + username_byte_size + domain_byte_size;
|
||||
|
||||
let layout = std::alloc::Layout::from_size_align(total_size, 8).unwrap();
|
||||
let buffer = unsafe { std::alloc::alloc_zeroed(layout) };
|
||||
|
||||
if buffer.is_null() {
|
||||
return Err(S4UError::Utf16Conversion("allocation failed".to_string()));
|
||||
}
|
||||
|
||||
// For the POC, we'll set up the raw bytes manually since the windows-rs
|
||||
// MSV1_0_S4U_LOGON structure layout may differ.
|
||||
// This is a simplified version - in production, use proper FFI bindings.
|
||||
|
||||
unsafe {
|
||||
// MessageType = MSV1_0_S4U_LOGON_TYPE (12)
|
||||
*(buffer as *mut u32) = MSV1_0_S4U_LOGON_TYPE;
|
||||
// Flags = 0
|
||||
*((buffer as *mut u32).add(1)) = 0;
|
||||
|
||||
// Copy username UTF-16 after the structure
|
||||
let username_offset = struct_size;
|
||||
let username_dest = buffer.add(username_offset);
|
||||
ptr::copy_nonoverlapping(
|
||||
username_utf16.as_ptr() as *const u8,
|
||||
username_dest,
|
||||
username_byte_size,
|
||||
);
|
||||
|
||||
// Copy domain UTF-16 after username
|
||||
let domain_offset = username_offset + username_byte_size;
|
||||
let domain_dest = buffer.add(domain_offset);
|
||||
ptr::copy_nonoverlapping(
|
||||
domain_utf16.as_ptr() as *const u8,
|
||||
domain_dest,
|
||||
domain_byte_size,
|
||||
);
|
||||
|
||||
// Set UNICODE_STRING for UserPrincipalName (offset 8 on 64-bit)
|
||||
// Length, MaximumLength, Buffer pointer
|
||||
let upn_ptr = buffer.add(8) as *mut u16;
|
||||
*upn_ptr = ((username_utf16.len() - 1) * 2) as u16; // Length (without null)
|
||||
*(upn_ptr.add(1)) = (username_utf16.len() * 2) as u16; // MaximumLength
|
||||
*((buffer.add(8 + 4)) as *mut *const u8) = username_dest; // Buffer
|
||||
|
||||
// Set UNICODE_STRING for DomainName
|
||||
let dn_offset = 8 + std::mem::size_of::<UnicodeStringRaw>();
|
||||
let dn_ptr = buffer.add(dn_offset) as *mut u16;
|
||||
*dn_ptr = ((domain_utf16.len() - 1) * 2) as u16;
|
||||
*(dn_ptr.add(1)) = (domain_utf16.len() * 2) as u16;
|
||||
*((buffer.add(dn_offset + 4)) as *mut *const u8) = domain_dest;
|
||||
}
|
||||
|
||||
Ok((buffer, total_size))
|
||||
}
|
||||
|
||||
/// Build KERB_S4U_LOGON structure for domain users.
|
||||
fn build_kerb_s4u_logon(username: &str, domain: &str) -> Result<(*mut u8, usize), S4UError> {
|
||||
// Build UPN: username@domain
|
||||
let upn = format!("{}@{}", username, domain);
|
||||
let upn_utf16: Vec<u16> = upn.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let upn_byte_size = upn_utf16.len() * 2;
|
||||
|
||||
let struct_size = std::mem::size_of::<KerbS4ULogonHeader>();
|
||||
let total_size = struct_size + upn_byte_size;
|
||||
|
||||
let layout = std::alloc::Layout::from_size_align(total_size, 8).unwrap();
|
||||
let buffer = unsafe { std::alloc::alloc_zeroed(layout) };
|
||||
|
||||
if buffer.is_null() {
|
||||
return Err(S4UError::Utf16Conversion("allocation failed".to_string()));
|
||||
}
|
||||
|
||||
unsafe {
|
||||
// MessageType = KERB_S4U_LOGON_TYPE (12)
|
||||
*(buffer as *mut u32) = KERB_S4U_LOGON_TYPE;
|
||||
// Flags = 0
|
||||
*((buffer as *mut u32).add(1)) = 0;
|
||||
|
||||
// Copy UPN UTF-16 after the structure
|
||||
let upn_offset = struct_size;
|
||||
let upn_dest = buffer.add(upn_offset);
|
||||
ptr::copy_nonoverlapping(
|
||||
upn_utf16.as_ptr() as *const u8,
|
||||
upn_dest,
|
||||
upn_byte_size,
|
||||
);
|
||||
|
||||
// Set UNICODE_STRING for ClientUpn (offset 8)
|
||||
let upn_str_ptr = buffer.add(8) as *mut u16;
|
||||
*upn_str_ptr = ((upn_utf16.len() - 1) * 2) as u16;
|
||||
*(upn_str_ptr.add(1)) = (upn_utf16.len() * 2) as u16;
|
||||
*((buffer.add(8 + 4)) as *mut *const u8) = upn_dest;
|
||||
|
||||
// ClientRealm is empty (zeroed)
|
||||
}
|
||||
|
||||
Ok((buffer, total_size))
|
||||
}
|
||||
|
||||
/// Raw UNICODE_STRING layout for size calculation.
|
||||
#[repr(C)]
|
||||
struct UnicodeStringRaw {
|
||||
_length: u16,
|
||||
_maximum_length: u16,
|
||||
_buffer: *const u16,
|
||||
}
|
||||
|
||||
/// Header size for MSV1_0_S4U_LOGON (MessageType + Flags + 2x UNICODE_STRING).
|
||||
#[repr(C)]
|
||||
struct MSV1_0_S4U_LOGON_HEADER {
|
||||
_message_type: u32,
|
||||
_flags: u32,
|
||||
_user_principal_name: UnicodeStringRaw,
|
||||
_domain_name: UnicodeStringRaw,
|
||||
}
|
||||
|
||||
/// Header size for KERB_S4U_LOGON (MessageType + Flags + 2x UNICODE_STRING).
|
||||
#[repr(C)]
|
||||
struct KerbS4ULogonHeader {
|
||||
_message_type: u32,
|
||||
_flags: u32,
|
||||
_client_upn: UnicodeStringRaw,
|
||||
_client_realm: UnicodeStringRaw,
|
||||
}
|
||||
21
client/rdp/server/addr.go
Normal file
21
client/rdp/server/addr.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// parseAddr parses a string into a netip.Addr, stripping any port or zone.
|
||||
func parseAddr(s string) (netip.Addr, error) {
|
||||
// Try as plain IP first
|
||||
if addr, err := netip.ParseAddr(s); err == nil {
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
// Try as IP:port
|
||||
if addrPort, err := netip.ParseAddrPort(s); err == nil {
|
||||
return addrPort.Addr(), nil
|
||||
}
|
||||
|
||||
return netip.Addr{}, fmt.Errorf("invalid IP address: %s", s)
|
||||
}
|
||||
184
client/rdp/server/pending.go
Normal file
184
client/rdp/server/pending.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultSessionTTL is the default time-to-live for pending RDP sessions.
|
||||
DefaultSessionTTL = 60 * time.Second
|
||||
|
||||
// cleanupInterval is how often the store checks for expired sessions.
|
||||
cleanupInterval = 10 * time.Second
|
||||
|
||||
// nonceLength is the length of the nonce in bytes.
|
||||
nonceLength = 32
|
||||
)
|
||||
|
||||
// PendingRDPSession represents an authorized but not yet consumed RDP session.
|
||||
type PendingRDPSession struct {
|
||||
SessionID string
|
||||
PeerIP netip.Addr
|
||||
OSUsername string
|
||||
Domain string
|
||||
JWTUserID string // for audit trail
|
||||
Nonce string // replay protection
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
consumed bool
|
||||
}
|
||||
|
||||
// PendingStore manages pending RDP session entries with automatic expiration.
|
||||
type PendingStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*PendingRDPSession // keyed by SessionID
|
||||
nonces map[string]struct{} // seen nonces for replay protection
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// NewPendingStore creates a new pending session store with the given TTL.
|
||||
func NewPendingStore(ttl time.Duration) *PendingStore {
|
||||
if ttl <= 0 {
|
||||
ttl = DefaultSessionTTL
|
||||
}
|
||||
return &PendingStore{
|
||||
sessions: make(map[string]*PendingRDPSession),
|
||||
nonces: make(map[string]struct{}),
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// Add creates a new pending RDP session and returns it.
|
||||
func (ps *PendingStore) Add(peerIP netip.Addr, osUsername, domain, jwtUserID, nonce string) (*PendingRDPSession, error) {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
// Check nonce for replay protection
|
||||
if _, seen := ps.nonces[nonce]; seen {
|
||||
return nil, fmt.Errorf("duplicate nonce: replay detected")
|
||||
}
|
||||
ps.nonces[nonce] = struct{}{}
|
||||
|
||||
now := time.Now()
|
||||
session := &PendingRDPSession{
|
||||
SessionID: uuid.New().String(),
|
||||
PeerIP: peerIP,
|
||||
OSUsername: osUsername,
|
||||
Domain: domain,
|
||||
JWTUserID: jwtUserID,
|
||||
Nonce: nonce,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(ps.ttl),
|
||||
}
|
||||
|
||||
ps.sessions[session.SessionID] = session
|
||||
|
||||
log.Debugf("RDP pending session created: id=%s peer=%s user=%s domain=%s expires=%s",
|
||||
session.SessionID, peerIP, osUsername, domain, session.ExpiresAt.Format(time.RFC3339))
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// QueryByPeerIP finds the first non-consumed, non-expired pending session for the given peer IP.
|
||||
func (ps *PendingStore) QueryByPeerIP(peerIP netip.Addr) (*PendingRDPSession, bool) {
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
for _, session := range ps.sessions {
|
||||
if session.PeerIP == peerIP && !session.consumed && now.Before(session.ExpiresAt) {
|
||||
return session, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Consume marks a session as consumed (single-use). Returns true if the session
|
||||
// was found and successfully consumed, false if it was already consumed, expired, or not found.
|
||||
func (ps *PendingStore) Consume(sessionID string) bool {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
session, exists := ps.sessions[sessionID]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if session.consumed {
|
||||
log.Debugf("RDP pending session already consumed: id=%s", sessionID)
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
log.Debugf("RDP pending session expired: id=%s", sessionID)
|
||||
return false
|
||||
}
|
||||
|
||||
session.consumed = true
|
||||
log.Debugf("RDP pending session consumed: id=%s peer=%s user=%s",
|
||||
sessionID, session.PeerIP, session.OSUsername)
|
||||
return true
|
||||
}
|
||||
|
||||
// StartCleanup runs a background goroutine that periodically removes expired sessions.
|
||||
func (ps *PendingStore) StartCleanup(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
ps.cleanup()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// cleanup removes expired and consumed sessions.
|
||||
func (ps *PendingStore) cleanup() {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, session := range ps.sessions {
|
||||
if now.After(session.ExpiresAt) || session.consumed {
|
||||
delete(ps.sessions, id)
|
||||
delete(ps.nonces, session.Nonce)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of active (non-expired, non-consumed) sessions.
|
||||
func (ps *PendingStore) Count() int {
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
|
||||
count := 0
|
||||
now := time.Now()
|
||||
for _, session := range ps.sessions {
|
||||
if !session.consumed && now.Before(session.ExpiresAt) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// GenerateNonce creates a cryptographically random nonce for replay protection.
|
||||
func GenerateNonce() (string, error) {
|
||||
b := make([]byte, nonceLength)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate nonce: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
268
client/rdp/server/pending_test.go
Normal file
268
client/rdp/server/pending_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPendingStore_AddAndQuery(t *testing.T) {
|
||||
store := NewPendingStore(DefaultSessionTTL)
|
||||
|
||||
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Add failed: %v", err)
|
||||
}
|
||||
|
||||
if session.SessionID == "" {
|
||||
t.Fatal("expected non-empty session ID")
|
||||
}
|
||||
if session.PeerIP != peerIP {
|
||||
t.Errorf("expected peer IP %s, got %s", peerIP, session.PeerIP)
|
||||
}
|
||||
if session.OSUsername != "admin" {
|
||||
t.Errorf("expected username admin, got %s", session.OSUsername)
|
||||
}
|
||||
|
||||
// Query should find the session
|
||||
found, ok := store.QueryByPeerIP(peerIP)
|
||||
if !ok {
|
||||
t.Fatal("expected to find pending session")
|
||||
}
|
||||
if found.SessionID != session.SessionID {
|
||||
t.Errorf("expected session %s, got %s", session.SessionID, found.SessionID)
|
||||
}
|
||||
|
||||
// Query for different IP should not find anything
|
||||
_, ok = store.QueryByPeerIP(netip.MustParseAddr("100.64.0.2"))
|
||||
if ok {
|
||||
t.Fatal("expected no session for different IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingStore_Consume(t *testing.T) {
|
||||
store := NewPendingStore(DefaultSessionTTL)
|
||||
|
||||
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-2")
|
||||
if err != nil {
|
||||
t.Fatalf("Add failed: %v", err)
|
||||
}
|
||||
|
||||
// First consume should succeed
|
||||
if !store.Consume(session.SessionID) {
|
||||
t.Fatal("expected first consume to succeed")
|
||||
}
|
||||
|
||||
// Second consume should fail (already consumed)
|
||||
if store.Consume(session.SessionID) {
|
||||
t.Fatal("expected second consume to fail")
|
||||
}
|
||||
|
||||
// Query should no longer find consumed session
|
||||
_, ok := store.QueryByPeerIP(peerIP)
|
||||
if ok {
|
||||
t.Fatal("expected consumed session to not be found by query")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingStore_Expiry(t *testing.T) {
|
||||
store := NewPendingStore(50 * time.Millisecond)
|
||||
|
||||
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-3")
|
||||
if err != nil {
|
||||
t.Fatalf("Add failed: %v", err)
|
||||
}
|
||||
|
||||
// Should be found immediately
|
||||
_, ok := store.QueryByPeerIP(peerIP)
|
||||
if !ok {
|
||||
t.Fatal("expected to find session before expiry")
|
||||
}
|
||||
|
||||
// Wait for expiry
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Should not be found after expiry
|
||||
_, ok = store.QueryByPeerIP(peerIP)
|
||||
if ok {
|
||||
t.Fatal("expected session to be expired")
|
||||
}
|
||||
|
||||
// Consume should also fail
|
||||
if store.Consume(session.SessionID) {
|
||||
t.Fatal("expected consume of expired session to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingStore_ReplayProtection(t *testing.T) {
|
||||
store := NewPendingStore(DefaultSessionTTL)
|
||||
|
||||
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-same")
|
||||
if err != nil {
|
||||
t.Fatalf("first Add failed: %v", err)
|
||||
}
|
||||
|
||||
// Same nonce should be rejected
|
||||
_, err = store.Add(peerIP, "admin", ".", "user@example.com", "nonce-same")
|
||||
if err == nil {
|
||||
t.Fatal("expected duplicate nonce to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingStore_Cleanup(t *testing.T) {
|
||||
store := NewPendingStore(50 * time.Millisecond)
|
||||
|
||||
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-cleanup")
|
||||
if err != nil {
|
||||
t.Fatalf("Add failed: %v", err)
|
||||
}
|
||||
|
||||
if store.Count() != 1 {
|
||||
t.Fatalf("expected count 1, got %d", store.Count())
|
||||
}
|
||||
|
||||
// Wait for expiry then trigger cleanup
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
store.cleanup()
|
||||
|
||||
if store.Count() != 0 {
|
||||
t.Fatalf("expected count 0 after cleanup, got %d", store.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingStore_CleanupBackground(t *testing.T) {
|
||||
store := NewPendingStore(50 * time.Millisecond)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
store.StartCleanup(ctx)
|
||||
|
||||
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-bg-cleanup")
|
||||
if err != nil {
|
||||
t.Fatalf("Add failed: %v", err)
|
||||
}
|
||||
|
||||
// Wait for expiry + cleanup interval
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
_, ok := store.QueryByPeerIP(peerIP)
|
||||
if ok {
|
||||
t.Fatal("expected session to be cleaned up")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingStore_ConcurrentAccess(t *testing.T) {
|
||||
store := NewPendingStore(DefaultSessionTTL)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
ip := netip.AddrFrom4([4]byte{100, 64, byte(i / 256), byte(i % 256)})
|
||||
nonce := "nonce-" + string(rune(i+'A'))
|
||||
if i >= 26 {
|
||||
nonce = "nonce-" + string(rune(i-26+'a'))
|
||||
}
|
||||
|
||||
session, err := store.Add(ip, "admin", ".", "user", nonce)
|
||||
if err != nil {
|
||||
return // nonce collision in test is expected
|
||||
}
|
||||
|
||||
store.QueryByPeerIP(ip)
|
||||
store.Consume(session.SessionID)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPendingStore_MultipleSessions(t *testing.T) {
|
||||
store := NewPendingStore(DefaultSessionTTL)
|
||||
|
||||
ip1 := netip.MustParseAddr("100.64.0.1")
|
||||
ip2 := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
s1, err := store.Add(ip1, "admin", ".", "user1", "nonce-a")
|
||||
if err != nil {
|
||||
t.Fatalf("Add s1 failed: %v", err)
|
||||
}
|
||||
|
||||
s2, err := store.Add(ip2, "jdoe", "DOMAIN", "user2", "nonce-b")
|
||||
if err != nil {
|
||||
t.Fatalf("Add s2 failed: %v", err)
|
||||
}
|
||||
|
||||
// Query each
|
||||
found1, ok := store.QueryByPeerIP(ip1)
|
||||
if !ok || found1.SessionID != s1.SessionID {
|
||||
t.Fatal("expected to find s1")
|
||||
}
|
||||
|
||||
found2, ok := store.QueryByPeerIP(ip2)
|
||||
if !ok || found2.SessionID != s2.SessionID {
|
||||
t.Fatal("expected to find s2")
|
||||
}
|
||||
|
||||
if found2.Domain != "DOMAIN" {
|
||||
t.Errorf("expected domain DOMAIN, got %s", found2.Domain)
|
||||
}
|
||||
|
||||
if store.Count() != 2 {
|
||||
t.Errorf("expected count 2, got %d", store.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateNonce(t *testing.T) {
|
||||
nonce1, err := GenerateNonce()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateNonce failed: %v", err)
|
||||
}
|
||||
|
||||
nonce2, err := GenerateNonce()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateNonce failed: %v", err)
|
||||
}
|
||||
|
||||
if len(nonce1) != nonceLength*2 { // hex encoding doubles the length
|
||||
t.Errorf("expected nonce length %d, got %d", nonceLength*2, len(nonce1))
|
||||
}
|
||||
|
||||
if nonce1 == nonce2 {
|
||||
t.Error("expected unique nonces")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseWindowsUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expectedUser string
|
||||
expectedDomain string
|
||||
}{
|
||||
{"admin", "admin", "."},
|
||||
{"DOMAIN\\admin", "admin", "DOMAIN"},
|
||||
{"admin@domain.com", "admin", "domain.com"},
|
||||
{".\\localuser", "localuser", "."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
user, domain := parseWindowsUsername(tt.input)
|
||||
if user != tt.expectedUser {
|
||||
t.Errorf("parseWindowsUsername(%q) user = %q, want %q", tt.input, user, tt.expectedUser)
|
||||
}
|
||||
if domain != tt.expectedDomain {
|
||||
t.Errorf("parseWindowsUsername(%q) domain = %q, want %q", tt.input, domain, tt.expectedDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
19
client/rdp/server/pipe_stub.go
Normal file
19
client/rdp/server/pipe_stub.go
Normal file
@@ -0,0 +1,19 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
import "context"
|
||||
|
||||
type stubPipeServer struct{}
|
||||
|
||||
func newPipeServer(_ *PendingStore) PipeServer {
|
||||
return &stubPipeServer{}
|
||||
}
|
||||
|
||||
func (s *stubPipeServer) Start(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubPipeServer) Stop() error {
|
||||
return nil
|
||||
}
|
||||
164
client/rdp/server/pipe_windows.go
Normal file
164
client/rdp/server/pipe_windows.go
Normal file
@@ -0,0 +1,164 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/Microsoft/go-winio"
|
||||
)
|
||||
|
||||
const (
|
||||
// PipeName is the named pipe path used for IPC between the NetBird agent and
|
||||
// the Credential Provider DLL.
|
||||
PipeName = `\\.\pipe\netbird-rdp-auth`
|
||||
|
||||
// pipeSDDL restricts access to LOCAL_SYSTEM (SY) and Administrators (BA).
|
||||
pipeSDDL = "D:P(A;;GA;;;SY)(A;;GA;;;BA)"
|
||||
|
||||
// maxPipeRequestSize is the maximum size of a pipe request in bytes.
|
||||
maxPipeRequestSize = 4096
|
||||
)
|
||||
|
||||
// windowsPipeServer implements the PipeServer interface for Windows.
|
||||
type windowsPipeServer struct {
|
||||
pending *PendingStore
|
||||
listener net.Listener
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newPipeServer(pending *PendingStore) PipeServer {
|
||||
return &windowsPipeServer{
|
||||
pending: pending,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *windowsPipeServer) Start(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||
|
||||
cfg := &winio.PipeConfig{
|
||||
SecurityDescriptor: pipeSDDL,
|
||||
}
|
||||
|
||||
listener, err := winio.ListenPipe(PipeName, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = listener
|
||||
|
||||
go s.acceptLoop()
|
||||
|
||||
log.Infof("RDP named pipe server started on %s", PipeName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *windowsPipeServer) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
if s.listener != nil {
|
||||
err := s.listener.Close()
|
||||
s.listener = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *windowsPipeServer) acceptLoop() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if s.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("RDP pipe accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handlePipeConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *windowsPipeServer) handlePipeConnection(conn net.Conn) {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("RDP pipe close: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(conn, maxPipeRequestSize))
|
||||
if err != nil {
|
||||
log.Debugf("RDP pipe read: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var req PipeRequest
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
log.Debugf("RDP pipe unmarshal: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var resp PipeResponse
|
||||
|
||||
switch req.Action {
|
||||
case PipeActionQuery:
|
||||
resp = s.handleQuery(req.RemoteIP)
|
||||
case PipeActionConsume:
|
||||
resp = s.handleConsume(req.SessionID)
|
||||
default:
|
||||
log.Debugf("RDP pipe unknown action: %s", req.Action)
|
||||
return
|
||||
}
|
||||
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Debugf("RDP pipe marshal response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := conn.Write(respData); err != nil {
|
||||
log.Debugf("RDP pipe write response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *windowsPipeServer) handleQuery(remoteIP string) PipeResponse {
|
||||
peerIP, err := parseAddr(remoteIP)
|
||||
if err != nil {
|
||||
log.Debugf("RDP pipe invalid remote IP: %s", remoteIP)
|
||||
return PipeResponse{Found: false}
|
||||
}
|
||||
|
||||
session, found := s.pending.QueryByPeerIP(peerIP)
|
||||
if !found {
|
||||
return PipeResponse{Found: false}
|
||||
}
|
||||
|
||||
return PipeResponse{
|
||||
Found: true,
|
||||
SessionID: session.SessionID,
|
||||
OSUser: session.OSUsername,
|
||||
Domain: session.Domain,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *windowsPipeServer) handleConsume(sessionID string) PipeResponse {
|
||||
if s.pending.Consume(sessionID) {
|
||||
return PipeResponse{Found: true, SessionID: sessionID}
|
||||
}
|
||||
return PipeResponse{Found: false}
|
||||
}
|
||||
48
client/rdp/server/protocol.go
Normal file
48
client/rdp/server/protocol.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package server
|
||||
|
||||
// AuthRequest is the sideband authorization request sent by the connecting peer
|
||||
// to the target peer's RDP auth server over the WireGuard tunnel.
|
||||
type AuthRequest struct {
|
||||
JWTToken string `json:"jwt_token"`
|
||||
RequestedUser string `json:"requested_user"`
|
||||
ClientPeerIP string `json:"client_peer_ip"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
// AuthResponse is the sideband authorization response sent by the target peer
|
||||
// back to the connecting peer.
|
||||
type AuthResponse struct {
|
||||
Status string `json:"status"` // "authorized" or "denied"
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
ExpiresAt int64 `json:"expires_at,omitempty"` // unix timestamp
|
||||
OSUser string `json:"os_user,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// PipeRequest is the IPC request from the Credential Provider DLL to the NetBird agent
|
||||
// via the named pipe.
|
||||
type PipeRequest struct {
|
||||
Action string `json:"action"` // "query_pending" or "consume"
|
||||
RemoteIP string `json:"remote_ip"` // connecting peer's WG IP
|
||||
SessionID string `json:"session_id,omitempty"` // for consume action
|
||||
}
|
||||
|
||||
// PipeResponse is the IPC response from the NetBird agent to the Credential Provider DLL.
|
||||
type PipeResponse struct {
|
||||
Found bool `json:"found"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
OSUser string `json:"os_user,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// StatusAuthorized indicates the RDP session was authorized.
|
||||
StatusAuthorized = "authorized"
|
||||
// StatusDenied indicates the RDP session was denied.
|
||||
StatusDenied = "denied"
|
||||
|
||||
// PipeActionQuery queries for a pending session by remote IP.
|
||||
PipeActionQuery = "query_pending"
|
||||
// PipeActionConsume marks a pending session as consumed.
|
||||
PipeActionConsume = "consume"
|
||||
)
|
||||
286
client/rdp/server/server.go
Normal file
286
client/rdp/server/server.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// InternalRDPAuthPort is the internal port the sideband auth server listens on.
|
||||
InternalRDPAuthPort = 22023
|
||||
|
||||
// DefaultRDPAuthPort is the external port on the WireGuard interface (DNAT target).
|
||||
DefaultRDPAuthPort = 3390
|
||||
|
||||
// maxRequestSize is the maximum size of an auth request in bytes.
|
||||
maxRequestSize = 64 * 1024
|
||||
|
||||
// connectionTimeout is the timeout for a single auth connection.
|
||||
connectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// JWTValidator validates JWT tokens and extracts user identity.
|
||||
type JWTValidator interface {
|
||||
ValidateAndExtract(token string) (userID string, err error)
|
||||
}
|
||||
|
||||
// Authorizer checks if a user is authorized for RDP access.
|
||||
type Authorizer interface {
|
||||
Authorize(jwtUserID, osUsername string) (string, error)
|
||||
}
|
||||
|
||||
// Server is the sideband RDP authorization server that listens on the WireGuard interface.
|
||||
type Server struct {
|
||||
listener net.Listener
|
||||
pending *PendingStore
|
||||
pipeServer PipeServer
|
||||
jwtValidator JWTValidator
|
||||
authorizer Authorizer
|
||||
networkAddr netip.Prefix // WireGuard network for source IP validation
|
||||
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// PipeServer is the interface for the named pipe IPC server (platform-specific).
|
||||
type PipeServer interface {
|
||||
Start(ctx context.Context) error
|
||||
Stop() error
|
||||
}
|
||||
|
||||
// Config holds the configuration for the RDP auth server.
|
||||
type Config struct {
|
||||
JWTValidator JWTValidator
|
||||
Authorizer Authorizer
|
||||
NetworkAddr netip.Prefix
|
||||
SessionTTL time.Duration
|
||||
}
|
||||
|
||||
// New creates a new RDP sideband auth server.
|
||||
func New(cfg *Config) *Server {
|
||||
ttl := cfg.SessionTTL
|
||||
if ttl <= 0 {
|
||||
ttl = DefaultSessionTTL
|
||||
}
|
||||
|
||||
pending := NewPendingStore(ttl)
|
||||
|
||||
return &Server{
|
||||
pending: pending,
|
||||
pipeServer: newPipeServer(pending),
|
||||
jwtValidator: cfg.JWTValidator,
|
||||
authorizer: cfg.Authorizer,
|
||||
networkAddr: cfg.NetworkAddr,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins listening for sideband auth requests on the given address.
|
||||
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener != nil {
|
||||
return errors.New("RDP auth server already running")
|
||||
}
|
||||
|
||||
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||
|
||||
listenAddr := net.TCPAddrFromAddrPort(addr)
|
||||
listener, err := net.ListenTCP("tcp", listenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||
}
|
||||
s.listener = listener
|
||||
|
||||
s.pending.StartCleanup(s.ctx)
|
||||
|
||||
if s.pipeServer != nil {
|
||||
if err := s.pipeServer.Start(s.ctx); err != nil {
|
||||
log.Warnf("failed to start RDP named pipe server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
go s.acceptLoop()
|
||||
|
||||
log.Infof("RDP sideband auth server started on %s", addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop shuts down the server and cleans up resources.
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
if s.pipeServer != nil {
|
||||
if err := s.pipeServer.Stop(); err != nil {
|
||||
log.Warnf("failed to stop RDP named pipe server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.listener != nil {
|
||||
err := s.listener.Close()
|
||||
s.listener = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("close listener: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("RDP sideband auth server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPendingStore returns the pending session store (for testing/named pipe access).
|
||||
func (s *Server) GetPendingStore() *PendingStore {
|
||||
return s.pending
|
||||
}
|
||||
|
||||
func (s *Server) acceptLoop() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if s.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("RDP auth accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("RDP auth close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(connectionTimeout)); err != nil {
|
||||
log.Debugf("RDP auth set deadline: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate source IP is from WireGuard network
|
||||
remoteAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
log.Debugf("RDP auth parse remote addr: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !s.networkAddr.Contains(remoteAddr.Addr()) {
|
||||
log.Warnf("RDP auth rejected connection from non-WG address: %s", remoteAddr.Addr())
|
||||
return
|
||||
}
|
||||
|
||||
// Read request
|
||||
data, err := io.ReadAll(io.LimitReader(conn, maxRequestSize))
|
||||
if err != nil {
|
||||
log.Debugf("RDP auth read request: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var req AuthRequest
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
log.Debugf("RDP auth unmarshal request: %v", err)
|
||||
s.sendResponse(conn, &AuthResponse{Status: StatusDenied, Reason: "invalid request format"})
|
||||
return
|
||||
}
|
||||
|
||||
response := s.processAuthRequest(remoteAddr.Addr(), &req)
|
||||
s.sendResponse(conn, response)
|
||||
}
|
||||
|
||||
func (s *Server) processAuthRequest(peerIP netip.Addr, req *AuthRequest) *AuthResponse {
|
||||
// Validate JWT
|
||||
if s.jwtValidator == nil {
|
||||
// No JWT validation configured - for POC, accept all requests from WG peers
|
||||
log.Warnf("RDP auth: no JWT validator configured, accepting request from %s", peerIP)
|
||||
return s.createSession(peerIP, req, "no-jwt-validation")
|
||||
}
|
||||
|
||||
userID, err := s.jwtValidator.ValidateAndExtract(req.JWTToken)
|
||||
if err != nil {
|
||||
log.Warnf("RDP auth JWT validation failed for %s: %v", peerIP, err)
|
||||
return &AuthResponse{Status: StatusDenied, Reason: "JWT validation failed"}
|
||||
}
|
||||
|
||||
// Check authorization
|
||||
if s.authorizer != nil {
|
||||
if _, err := s.authorizer.Authorize(userID, req.RequestedUser); err != nil {
|
||||
log.Warnf("RDP auth denied for user %s -> %s: %v", userID, req.RequestedUser, err)
|
||||
return &AuthResponse{Status: StatusDenied, Reason: "not authorized for this user"}
|
||||
}
|
||||
}
|
||||
|
||||
return s.createSession(peerIP, req, userID)
|
||||
}
|
||||
|
||||
func (s *Server) createSession(peerIP netip.Addr, req *AuthRequest, jwtUserID string) *AuthResponse {
|
||||
// Parse domain from requested user (DOMAIN\user or user@domain)
|
||||
osUser, domain := parseWindowsUsername(req.RequestedUser)
|
||||
|
||||
session, err := s.pending.Add(peerIP, osUser, domain, jwtUserID, req.Nonce)
|
||||
if err != nil {
|
||||
log.Warnf("RDP auth create session failed: %v", err)
|
||||
return &AuthResponse{Status: StatusDenied, Reason: err.Error()}
|
||||
}
|
||||
|
||||
return &AuthResponse{
|
||||
Status: StatusAuthorized,
|
||||
SessionID: session.SessionID,
|
||||
ExpiresAt: session.ExpiresAt.Unix(),
|
||||
OSUser: session.OSUsername,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) sendResponse(conn net.Conn, resp *AuthResponse) {
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Debugf("RDP auth marshal response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
log.Debugf("RDP auth write response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// parseWindowsUsername extracts username and domain from Windows username formats.
|
||||
// Supports DOMAIN\username, username@domain, and plain username.
|
||||
func parseWindowsUsername(fullUsername string) (username, domain string) {
|
||||
for i := len(fullUsername) - 1; i >= 0; i-- {
|
||||
if fullUsername[i] == '\\' {
|
||||
return fullUsername[i+1:], fullUsername[:i]
|
||||
}
|
||||
}
|
||||
|
||||
if idx := indexOf(fullUsername, '@'); idx != -1 {
|
||||
return fullUsername[:idx], fullUsername[idx+1:]
|
||||
}
|
||||
|
||||
return fullUsername, "."
|
||||
}
|
||||
|
||||
func indexOf(s string, c byte) int {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == c {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
Reference in New Issue
Block a user