mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 16:56:39 +00:00
Compare commits
3 Commits
fix/androi
...
claude/rdp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b621a2628b | ||
|
|
4949ca6194 | ||
|
|
c5186f1483 |
276
client/cmd/rdp.go
Normal file
276
client/cmd/rdp.go
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"os/user"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
rdpclient "github.com/netbirdio/netbird/client/rdp/client"
|
||||||
|
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
serverRDPAllowedFlag = "allow-server-rdp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
rdpUsername string
|
||||||
|
rdpHost string
|
||||||
|
rdpNoBrowser bool
|
||||||
|
rdpNoCache bool
|
||||||
|
serverRDPAllowed bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rdpCmd.PersistentFlags().StringVarP(&rdpUsername, "user", "u", "", "Windows username on remote peer")
|
||||||
|
rdpCmd.PersistentFlags().BoolVar(&rdpNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
rdpCmd.PersistentFlags().BoolVar(&rdpNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&serverRDPAllowed, serverRDPAllowedFlag, false, "Allow RDP passthrough on peer (passwordless RDP via credential provider)")
|
||||||
|
}
|
||||||
|
|
||||||
|
var rdpCmd = &cobra.Command{
|
||||||
|
Use: "rdp [flags] [user@]host",
|
||||||
|
Short: "Connect to a NetBird peer via RDP (passwordless)",
|
||||||
|
Long: `Connect to a NetBird peer using Remote Desktop Protocol with token-based
|
||||||
|
passwordless authentication. The target peer must have RDP passthrough enabled.
|
||||||
|
|
||||||
|
This command:
|
||||||
|
1. Obtains a JWT token via OIDC authentication
|
||||||
|
2. Sends the token to the target peer's sideband auth service
|
||||||
|
3. If authorized, launches mstsc.exe to connect
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
netbird rdp peer-hostname
|
||||||
|
netbird rdp administrator@peer-hostname
|
||||||
|
netbird rdp --user admin peer-hostname`,
|
||||||
|
Args: cobra.MinimumNArgs(1),
|
||||||
|
RunE: rdpFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
func rdpFn(cmd *cobra.Command, args []string) error {
|
||||||
|
SetFlagsFromEnvVars(rootCmd)
|
||||||
|
SetFlagsFromEnvVars(cmd)
|
||||||
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
|
logOutput := "console"
|
||||||
|
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||||
|
logOutput = firstLogFile
|
||||||
|
}
|
||||||
|
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse user@host
|
||||||
|
if err := parseRDPHostArg(args[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
|
||||||
|
sig := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
rdpCtx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
if err := runRDP(rdpCtx, cmd); err != nil {
|
||||||
|
errCh <- err
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-sig:
|
||||||
|
cancel()
|
||||||
|
<-rdpCtx.Done()
|
||||||
|
return nil
|
||||||
|
case err := <-errCh:
|
||||||
|
return err
|
||||||
|
case <-rdpCtx.Done():
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRDPHostArg(arg string) error {
|
||||||
|
if strings.Contains(arg, "@") {
|
||||||
|
parts := strings.SplitN(arg, "@", 2)
|
||||||
|
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||||
|
return errors.New("invalid user@host format")
|
||||||
|
}
|
||||||
|
if rdpUsername == "" {
|
||||||
|
rdpUsername = parts[0]
|
||||||
|
}
|
||||||
|
rdpHost = parts[1]
|
||||||
|
} else {
|
||||||
|
rdpHost = arg
|
||||||
|
}
|
||||||
|
|
||||||
|
if rdpUsername == "" {
|
||||||
|
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||||
|
rdpUsername = sudoUser
|
||||||
|
} else if currentUser, err := user.Current(); err == nil {
|
||||||
|
rdpUsername = currentUser.Username
|
||||||
|
} else {
|
||||||
|
rdpUsername = "Administrator"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runRDP(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
// Connect to daemon
|
||||||
|
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||||
|
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect to daemon: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = grpcConn.Close() }()
|
||||||
|
|
||||||
|
daemonClient := proto.NewDaemonServiceClient(grpcConn)
|
||||||
|
|
||||||
|
// Resolve peer IP
|
||||||
|
peerIP, err := resolvePeerIP(ctx, daemonClient, rdpHost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("resolve peer %s: %w", rdpHost, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("Connecting to %s@%s (%s)...\n", rdpUsername, rdpHost, peerIP)
|
||||||
|
|
||||||
|
// Obtain JWT token
|
||||||
|
hint := profilemanager.GetLoginHint()
|
||||||
|
var browserOpener func(string) error
|
||||||
|
if !rdpNoBrowser {
|
||||||
|
browserOpener = util.OpenBrowser
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtToken, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !rdpNoCache, hint, browserOpener)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("JWT authentication: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("JWT authentication successful")
|
||||||
|
cmd.Println("Authenticated. Requesting RDP access...")
|
||||||
|
|
||||||
|
// Generate nonce for replay protection
|
||||||
|
nonce, err := rdpserver.GenerateNonce()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate nonce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sideband auth request
|
||||||
|
authClient := rdpclient.New()
|
||||||
|
authAddr := net.JoinHostPort(peerIP, fmt.Sprintf("%d", rdpserver.DefaultRDPAuthPort))
|
||||||
|
|
||||||
|
resp, err := authClient.RequestAuth(ctx, authAddr, &rdpserver.AuthRequest{
|
||||||
|
JWTToken: jwtToken,
|
||||||
|
RequestedUser: rdpUsername,
|
||||||
|
ClientPeerIP: "", // will be filled by the server from the connection
|
||||||
|
Nonce: nonce,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
cmd.Printf("Failed to authorize RDP session with %s\n", rdpHost)
|
||||||
|
cmd.Printf("\nTroubleshooting:\n")
|
||||||
|
cmd.Printf(" 1. Check connectivity: netbird status -d\n")
|
||||||
|
cmd.Printf(" 2. Verify RDP passthrough is enabled on the target peer\n")
|
||||||
|
return fmt.Errorf("sideband auth: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Status != rdpserver.StatusAuthorized {
|
||||||
|
return fmt.Errorf("RDP access denied: %s", resp.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("RDP access authorized (session: %s, user: %s)\n", resp.SessionID, resp.OSUser)
|
||||||
|
cmd.Printf("Launching Remote Desktop client...\n")
|
||||||
|
|
||||||
|
// Launch mstsc.exe (platform-specific)
|
||||||
|
if err := launchRDPClient(peerIP); err != nil {
|
||||||
|
return fmt.Errorf("launch RDP client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolvePeerIP resolves a peer hostname/FQDN to its WireGuard IP address
|
||||||
|
// by querying the daemon for the current peer status.
|
||||||
|
func resolvePeerIP(ctx context.Context, client proto.DaemonServiceClient, peerAddress string) (string, error) {
|
||||||
|
statusResp, err := client.Status(ctx, &proto.StatusRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("get daemon status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusResp.GetFullStatus() == nil {
|
||||||
|
return "", errors.New("daemon returned empty status")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range statusResp.GetFullStatus().GetPeers() {
|
||||||
|
if matchesPeer(peer, peerAddress) {
|
||||||
|
ip := peer.GetIP()
|
||||||
|
if ip == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Strip CIDR suffix if present
|
||||||
|
if idx := strings.Index(ip, "/"); idx != -1 {
|
||||||
|
ip = ip[:idx]
|
||||||
|
}
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not found as a peer name, try as a direct IP
|
||||||
|
if addr, err := net.ResolveIPAddr("ip", peerAddress); err == nil {
|
||||||
|
return addr.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("peer %q not found in network", peerAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func matchesPeer(peer *proto.PeerState, address string) bool {
|
||||||
|
address = strings.ToLower(address)
|
||||||
|
|
||||||
|
if strings.EqualFold(peer.GetFqdn(), address) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match against FQDN without trailing dot
|
||||||
|
fqdn := strings.TrimSuffix(peer.GetFqdn(), ".")
|
||||||
|
if strings.EqualFold(fqdn, address) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match against short hostname (first part of FQDN)
|
||||||
|
if parts := strings.SplitN(fqdn, ".", 2); len(parts) > 0 {
|
||||||
|
if strings.EqualFold(parts[0], address) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match against IP
|
||||||
|
ip := peer.GetIP()
|
||||||
|
if idx := strings.Index(ip, "/"); idx != -1 {
|
||||||
|
ip = ip[:idx]
|
||||||
|
}
|
||||||
|
if ip == address {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
13
client/cmd/rdp_stub.go
Normal file
13
client/cmd/rdp_stub.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// launchRDPClient is a stub for non-Windows platforms.
|
||||||
|
func launchRDPClient(peerIP string) error {
|
||||||
|
fmt.Printf("RDP session authorized for %s\n", peerIP)
|
||||||
|
fmt.Println("Note: mstsc.exe is only available on Windows.")
|
||||||
|
fmt.Printf("Use any RDP client to connect to %s:3389\n", peerIP)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
34
client/cmd/rdp_windows.go
Normal file
34
client/cmd/rdp_windows.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// launchRDPClient launches the native Windows Remote Desktop client (mstsc.exe).
|
||||||
|
func launchRDPClient(peerIP string) error {
|
||||||
|
mstscPath, err := exec.LookPath("mstsc.exe")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("mstsc.exe not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(mstscPath, fmt.Sprintf("/v:%s", peerIP))
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("start mstsc.exe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("launched mstsc.exe (PID %d) connecting to %s", cmd.Process.Pid, peerIP)
|
||||||
|
|
||||||
|
// Don't wait for mstsc to exit - it runs independently
|
||||||
|
go func() {
|
||||||
|
if err := cmd.Wait(); err != nil {
|
||||||
|
log.Debugf("mstsc.exe exited: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -150,6 +150,7 @@ func init() {
|
|||||||
rootCmd.AddCommand(logoutCmd)
|
rootCmd.AddCommand(logoutCmd)
|
||||||
rootCmd.AddCommand(versionCmd)
|
rootCmd.AddCommand(versionCmd)
|
||||||
rootCmd.AddCommand(sshCmd)
|
rootCmd.AddCommand(sshCmd)
|
||||||
|
rootCmd.AddCommand(rdpCmd)
|
||||||
rootCmd.AddCommand(networksCMD)
|
rootCmd.AddCommand(networksCMD)
|
||||||
rootCmd.AddCommand(forwardingRulesCmd)
|
rootCmd.AddCommand(forwardingRulesCmd)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
|
|||||||
@@ -356,6 +356,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
|||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
req.ServerSSHAllowed = &serverSSHAllowed
|
req.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(serverRDPAllowedFlag).Changed {
|
||||||
|
req.ServerRDPAllowed = &serverRDPAllowed
|
||||||
|
}
|
||||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
req.EnableSSHRoot = &enableSSHRoot
|
req.EnableSSHRoot = &enableSSHRoot
|
||||||
}
|
}
|
||||||
@@ -458,6 +461,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
|||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(serverRDPAllowedFlag).Changed {
|
||||||
|
ic.ServerRDPAllowed = &serverRDPAllowed
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
ic.EnableSSHRoot = &enableSSHRoot
|
ic.EnableSSHRoot = &enableSSHRoot
|
||||||
@@ -582,6 +588,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(serverRDPAllowedFlag).Changed {
|
||||||
|
loginRequest.ServerRDPAllowed = &serverRDPAllowed
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||||
|
|||||||
@@ -543,6 +543,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
|||||||
RosenpassEnabled: config.RosenpassEnabled,
|
RosenpassEnabled: config.RosenpassEnabled,
|
||||||
RosenpassPermissive: config.RosenpassPermissive,
|
RosenpassPermissive: config.RosenpassPermissive,
|
||||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||||
|
ServerRDPAllowed: config.ServerRDPAllowed != nil && *config.ServerRDPAllowed,
|
||||||
EnableSSHRoot: config.EnableSSHRoot,
|
EnableSSHRoot: config.EnableSSHRoot,
|
||||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||||
|
|||||||
@@ -117,6 +117,7 @@ type EngineConfig struct {
|
|||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
|
|
||||||
ServerSSHAllowed bool
|
ServerSSHAllowed bool
|
||||||
|
ServerRDPAllowed bool
|
||||||
EnableSSHRoot *bool
|
EnableSSHRoot *bool
|
||||||
EnableSSHSFTP *bool
|
EnableSSHSFTP *bool
|
||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
@@ -197,6 +198,7 @@ type Engine struct {
|
|||||||
networkMonitor *networkmonitor.NetworkMonitor
|
networkMonitor *networkmonitor.NetworkMonitor
|
||||||
|
|
||||||
sshServer sshServer
|
sshServer sshServer
|
||||||
|
rdpServer rdpServer
|
||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
@@ -1036,6 +1038,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := e.updateRDP(); err != nil {
|
||||||
|
log.Warnf("failed handling RDP server setup: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
state := e.statusRecorder.GetLocalPeerState()
|
state := e.statusRecorder.GetLocalPeerState()
|
||||||
state.IP = e.wgInterface.Address().String()
|
state.IP = e.wgInterface.Address().String()
|
||||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||||
@@ -1323,6 +1329,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||||
|
|
||||||
|
// Reuse SSH ACL for RDP authorization
|
||||||
|
e.updateRDPServerAuth(networkMap.GetSshAuth())
|
||||||
}
|
}
|
||||||
|
|
||||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||||
|
|||||||
191
client/internal/engine_rdp.go
Normal file
191
client/internal/engine_rdp.go
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rdpServer interface {
|
||||||
|
Start(ctx context.Context, addr netip.AddrPort) error
|
||||||
|
Stop() error
|
||||||
|
GetPendingStore() *rdpserver.PendingStore
|
||||||
|
UpdateRDPAuth(config *sshauth.Config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) setupRDPPortRedirection() error {
|
||||||
|
if e.firewall == nil || e.wgInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := e.wgInterface.Address().IP
|
||||||
|
if !localAddr.IsValid() {
|
||||||
|
return errors.New("invalid local NetBird address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, rdpserver.DefaultRDPAuthPort, rdpserver.InternalRDPAuthPort); err != nil {
|
||||||
|
return fmt.Errorf("add RDP auth port redirection: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("RDP auth port redirection enabled: %s:%d -> %s:%d",
|
||||||
|
localAddr, rdpserver.DefaultRDPAuthPort, localAddr, rdpserver.InternalRDPAuthPort)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) cleanupRDPPortRedirection() error {
|
||||||
|
if e.firewall == nil || e.wgInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := e.wgInterface.Address().IP
|
||||||
|
if !localAddr.IsValid() {
|
||||||
|
return errors.New("invalid local NetBird address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, rdpserver.DefaultRDPAuthPort, rdpserver.InternalRDPAuthPort); err != nil {
|
||||||
|
return fmt.Errorf("remove RDP auth port redirection: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("RDP auth port redirection removed: %s:%d -> %s:%d",
|
||||||
|
localAddr, rdpserver.DefaultRDPAuthPort, localAddr, rdpserver.InternalRDPAuthPort)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateRDP handles starting/stopping the RDP server based on the config flag.
|
||||||
|
func (e *Engine) updateRDP() error {
|
||||||
|
if !e.config.ServerRDPAllowed {
|
||||||
|
if e.rdpServer != nil {
|
||||||
|
log.Info("RDP passthrough disabled, stopping RDP auth server")
|
||||||
|
}
|
||||||
|
return e.stopRDPServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.BlockInbound {
|
||||||
|
log.Info("RDP server is disabled because inbound connections are blocked")
|
||||||
|
return e.stopRDPServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.rdpServer != nil {
|
||||||
|
log.Debug("RDP auth server is already running")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.startRDPServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) startRDPServer() error {
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return errors.New("wg interface not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
wgAddr := e.wgInterface.Address()
|
||||||
|
|
||||||
|
cfg := &rdpserver.Config{
|
||||||
|
NetworkAddr: wgAddr.Network,
|
||||||
|
}
|
||||||
|
|
||||||
|
server := rdpserver.New(cfg)
|
||||||
|
|
||||||
|
netbirdIP := wgAddr.IP
|
||||||
|
listenAddr := netip.AddrPortFrom(netbirdIP, rdpserver.InternalRDPAuthPort)
|
||||||
|
|
||||||
|
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||||
|
return fmt.Errorf("start RDP auth server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.rdpServer = server
|
||||||
|
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
|
if registrar, ok := e.firewall.(interface {
|
||||||
|
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
|
}); ok {
|
||||||
|
registrar.RegisterNetstackService(nftypes.TCP, rdpserver.InternalRDPAuthPort)
|
||||||
|
log.Debugf("registered RDP auth service with netstack for TCP:%d", rdpserver.InternalRDPAuthPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.setupRDPPortRedirection(); err != nil {
|
||||||
|
log.Warnf("failed to setup RDP auth port redirection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register the credential provider DLL dynamically (Windows only)
|
||||||
|
if err := rdpserver.RegisterCredentialProvider(); err != nil {
|
||||||
|
log.Warnf("failed to register RDP credential provider (passwordless RDP will not work): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("RDP passthrough enabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) stopRDPServer() error {
|
||||||
|
if e.rdpServer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.cleanupRDPPortRedirection(); err != nil {
|
||||||
|
log.Warnf("failed to cleanup RDP auth port redirection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
|
if registrar, ok := e.firewall.(interface {
|
||||||
|
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
|
}); ok {
|
||||||
|
registrar.UnregisterNetstackService(nftypes.TCP, rdpserver.InternalRDPAuthPort)
|
||||||
|
log.Debugf("unregistered RDP auth service from netstack for TCP:%d", rdpserver.InternalRDPAuthPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unregister the credential provider DLL (Windows only)
|
||||||
|
if err := rdpserver.UnregisterCredentialProvider(); err != nil {
|
||||||
|
log.Warnf("failed to unregister RDP credential provider: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("stopping RDP auth server")
|
||||||
|
err := e.rdpServer.Stop()
|
||||||
|
e.rdpServer = nil
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stop: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateRDPServerAuth reuses the SSH authorization config for RDP access control.
|
||||||
|
// This means the same user/machine-user mappings that control SSH access also control RDP.
|
||||||
|
func (e *Engine) updateRDPServerAuth(sshAuth *mgmProto.SSHAuth) {
|
||||||
|
if sshAuth == nil || e.rdpServer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
protoUsers := sshAuth.GetAuthorizedUsers()
|
||||||
|
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||||
|
for i, hash := range protoUsers {
|
||||||
|
if len(hash) != 16 {
|
||||||
|
log.Warnf("invalid hash length %d, expected 16 - skipping RDP server auth update", len(hash))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
machineUsers := make(map[string][]uint32)
|
||||||
|
for osUser, indexes := range sshAuth.GetMachineUsers() {
|
||||||
|
machineUsers[osUser] = indexes.GetIndexes()
|
||||||
|
}
|
||||||
|
|
||||||
|
authConfig := &sshauth.Config{
|
||||||
|
UserIDClaim: sshAuth.GetUserIDClaim(),
|
||||||
|
AuthorizedUsers: authorizedUsers,
|
||||||
|
MachineUsers: machineUsers,
|
||||||
|
}
|
||||||
|
|
||||||
|
e.rdpServer.UpdateRDPAuth(authConfig)
|
||||||
|
}
|
||||||
@@ -64,6 +64,7 @@ type ConfigInput struct {
|
|||||||
StateFilePath string
|
StateFilePath string
|
||||||
PreSharedKey *string
|
PreSharedKey *string
|
||||||
ServerSSHAllowed *bool
|
ServerSSHAllowed *bool
|
||||||
|
ServerRDPAllowed *bool
|
||||||
EnableSSHRoot *bool
|
EnableSSHRoot *bool
|
||||||
EnableSSHSFTP *bool
|
EnableSSHSFTP *bool
|
||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
@@ -114,6 +115,7 @@ type Config struct {
|
|||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
ServerSSHAllowed *bool
|
ServerSSHAllowed *bool
|
||||||
|
ServerRDPAllowed *bool
|
||||||
EnableSSHRoot *bool
|
EnableSSHRoot *bool
|
||||||
EnableSSHSFTP *bool
|
EnableSSHSFTP *bool
|
||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
@@ -415,6 +417,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.ServerRDPAllowed != nil {
|
||||||
|
if config.ServerRDPAllowed == nil || *input.ServerRDPAllowed != *config.ServerRDPAllowed {
|
||||||
|
if *input.ServerRDPAllowed {
|
||||||
|
log.Infof("enabling RDP passthrough")
|
||||||
|
} else {
|
||||||
|
log.Infof("disabling RDP passthrough")
|
||||||
|
}
|
||||||
|
config.ServerRDPAllowed = input.ServerRDPAllowed
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
} else if config.ServerRDPAllowed == nil {
|
||||||
|
config.ServerRDPAllowed = util.False()
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||||
if *input.EnableSSHRoot {
|
if *input.EnableSSHRoot {
|
||||||
log.Infof("enabling SSH root login")
|
log.Infof("enabling SSH root login")
|
||||||
|
|||||||
@@ -472,6 +472,7 @@ type LoginRequest struct {
|
|||||||
EnableSSHRemotePortForwarding *bool `protobuf:"varint,37,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
|
EnableSSHRemotePortForwarding *bool `protobuf:"varint,37,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
|
||||||
DisableSSHAuth *bool `protobuf:"varint,38,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
|
DisableSSHAuth *bool `protobuf:"varint,38,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
|
||||||
SshJWTCacheTTL *int32 `protobuf:"varint,39,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
|
SshJWTCacheTTL *int32 `protobuf:"varint,39,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
|
||||||
|
ServerRDPAllowed *bool `protobuf:"varint,40,opt,name=serverRDPAllowed,proto3,oneof" json:"serverRDPAllowed,omitempty"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
}
|
}
|
||||||
@@ -780,6 +781,13 @@ func (x *LoginRequest) GetSshJWTCacheTTL() int32 {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *LoginRequest) GetServerRDPAllowed() bool {
|
||||||
|
if x != nil && x.ServerRDPAllowed != nil {
|
||||||
|
return *x.ServerRDPAllowed
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type LoginResponse struct {
|
type LoginResponse struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
|
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
|
||||||
@@ -1312,6 +1320,7 @@ type GetConfigResponse struct {
|
|||||||
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
|
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
|
||||||
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
|
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
|
||||||
SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"`
|
SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"`
|
||||||
|
ServerRDPAllowed bool `protobuf:"varint,27,opt,name=serverRDPAllowed,proto3" json:"serverRDPAllowed,omitempty"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
}
|
}
|
||||||
@@ -1528,6 +1537,13 @@ func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *GetConfigResponse) GetServerRDPAllowed() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.ServerRDPAllowed
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// PeerState contains the latest state of a peer
|
// PeerState contains the latest state of a peer
|
||||||
type PeerState struct {
|
type PeerState struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
@@ -4139,6 +4155,7 @@ type SetConfigRequest struct {
|
|||||||
EnableSSHRemotePortForwarding *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
|
EnableSSHRemotePortForwarding *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
|
||||||
DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
|
DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
|
||||||
SshJWTCacheTTL *int32 `protobuf:"varint,34,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
|
SshJWTCacheTTL *int32 `protobuf:"varint,34,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
|
||||||
|
ServerRDPAllowed *bool `protobuf:"varint,35,opt,name=serverRDPAllowed,proto3,oneof" json:"serverRDPAllowed,omitempty"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
}
|
}
|
||||||
@@ -4411,6 +4428,13 @@ func (x *SetConfigRequest) GetSshJWTCacheTTL() int32 {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *SetConfigRequest) GetServerRDPAllowed() bool {
|
||||||
|
if x != nil && x.ServerRDPAllowed != nil {
|
||||||
|
return *x.ServerRDPAllowed
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type SetConfigResponse struct {
|
type SetConfigResponse struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
|
|||||||
@@ -209,6 +209,8 @@ message LoginRequest {
|
|||||||
optional bool enableSSHRemotePortForwarding = 37;
|
optional bool enableSSHRemotePortForwarding = 37;
|
||||||
optional bool disableSSHAuth = 38;
|
optional bool disableSSHAuth = 38;
|
||||||
optional int32 sshJWTCacheTTL = 39;
|
optional int32 sshJWTCacheTTL = 39;
|
||||||
|
|
||||||
|
optional bool serverRDPAllowed = 40;
|
||||||
}
|
}
|
||||||
|
|
||||||
message LoginResponse {
|
message LoginResponse {
|
||||||
@@ -316,6 +318,8 @@ message GetConfigResponse {
|
|||||||
bool disableSSHAuth = 25;
|
bool disableSSHAuth = 25;
|
||||||
|
|
||||||
int32 sshJWTCacheTTL = 26;
|
int32 sshJWTCacheTTL = 26;
|
||||||
|
|
||||||
|
bool serverRDPAllowed = 27;
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerState contains the latest state of a peer
|
// PeerState contains the latest state of a peer
|
||||||
@@ -677,6 +681,8 @@ message SetConfigRequest {
|
|||||||
optional bool enableSSHRemotePortForwarding = 32;
|
optional bool enableSSHRemotePortForwarding = 32;
|
||||||
optional bool disableSSHAuth = 33;
|
optional bool disableSSHAuth = 33;
|
||||||
optional int32 sshJWTCacheTTL = 34;
|
optional int32 sshJWTCacheTTL = 34;
|
||||||
|
|
||||||
|
optional bool serverRDPAllowed = 35;
|
||||||
}
|
}
|
||||||
|
|
||||||
message SetConfigResponse{}
|
message SetConfigResponse{}
|
||||||
|
|||||||
88
client/rdp/client/client.go
Normal file
88
client/rdp/client/client.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultTimeout is the default timeout for sideband auth requests.
|
||||||
|
DefaultTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
// maxResponseSize is the maximum size of an auth response in bytes.
|
||||||
|
maxResponseSize = 64 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client connects to a target peer's RDP sideband auth server to request access.
|
||||||
|
type Client struct {
|
||||||
|
Timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new sideband RDP auth client.
|
||||||
|
func New() *Client {
|
||||||
|
return &Client{
|
||||||
|
Timeout: DefaultTimeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestAuth sends an authorization request to the target peer's sideband server
|
||||||
|
// and returns the response. The addr should be in "host:port" format.
|
||||||
|
func (c *Client) RequestAuth(ctx context.Context, addr string, req *rdpserver.AuthRequest) (*rdpserver.AuthResponse, error) {
|
||||||
|
timeout := c.Timeout
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = DefaultTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("connect to RDP auth server at %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
deadline, ok := ctx.Deadline()
|
||||||
|
if ok {
|
||||||
|
if err := conn.SetDeadline(deadline); err != nil {
|
||||||
|
return nil, fmt.Errorf("set connection deadline: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send request
|
||||||
|
reqData, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal auth request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Write(reqData); err != nil {
|
||||||
|
return nil, fmt.Errorf("send auth request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signal we're done writing so the server can read the full request
|
||||||
|
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||||
|
if err := tcpConn.CloseWrite(); err != nil {
|
||||||
|
return nil, fmt.Errorf("close write: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
respData, err := io.ReadAll(io.LimitReader(conn, maxResponseSize))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read auth response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp rdpserver.AuthResponse
|
||||||
|
if err := json.Unmarshal(respData, &resp); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal auth response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &resp, nil
|
||||||
|
}
|
||||||
31
client/rdp/credprov/Cargo.toml
Normal file
31
client/rdp/credprov/Cargo.toml
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
[package]
|
||||||
|
name = "netbird-credprov"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
description = "NetBird RDP Credential Provider for Windows"
|
||||||
|
license = "BSD-3-Clause"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
windows = { version = "0.58", features = [
|
||||||
|
"implement",
|
||||||
|
"Win32_Foundation",
|
||||||
|
"Win32_System_Com",
|
||||||
|
"Win32_UI_Shell",
|
||||||
|
"Win32_Security",
|
||||||
|
"Win32_Security_Authentication_Identity",
|
||||||
|
"Win32_Security_Credentials",
|
||||||
|
"Win32_System_RemoteDesktop",
|
||||||
|
"Win32_System_Threading",
|
||||||
|
] }
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
uuid = { version = "1", features = ["v4"] }
|
||||||
|
log = "0.4"
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
opt-level = "s"
|
||||||
|
lto = true
|
||||||
|
strip = true
|
||||||
210
client/rdp/credprov/src/credential.rs
Normal file
210
client/rdp/credprov/src/credential.rs
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
//! ICredentialProviderCredential implementation.
|
||||||
|
//!
|
||||||
|
//! Represents a single "NetBird Login" credential tile on the Windows login screen.
|
||||||
|
//! When selected, it queries the local NetBird agent for pending RDP sessions and
|
||||||
|
//! performs S4U logon to authenticate the user without a password.
|
||||||
|
|
||||||
|
use crate::named_pipe_client::{NamedPipeClient, PipeResponse};
|
||||||
|
use crate::s4u;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use windows::core::*;
|
||||||
|
use windows::Win32::Foundation::*;
|
||||||
|
use windows::Win32::Security::Credentials::*;
|
||||||
|
use windows::Win32::UI::Shell::*;
|
||||||
|
|
||||||
|
/// NetBird credential tile that appears on the Windows login screen.
|
||||||
|
#[implement(ICredentialProviderCredential)]
|
||||||
|
pub struct NetBirdCredential {
|
||||||
|
/// The pending session information from the NetBird agent.
|
||||||
|
session: Mutex<Option<PipeResponse>>,
|
||||||
|
/// The remote IP address of the connecting peer.
|
||||||
|
remote_ip: Mutex<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NetBirdCredential {
|
||||||
|
pub fn new(remote_ip: String, session: PipeResponse) -> Self {
|
||||||
|
Self {
|
||||||
|
session: Mutex::new(Some(session)),
|
||||||
|
remote_ip: Mutex::new(remote_ip),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ICredentialProviderCredential_Impl for NetBirdCredential_Impl {
|
||||||
|
fn Advise(&self, _pcpce: Option<&ICredentialProviderCredentialEvents>) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn UnAdvise(&self) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn SetSelected(&self, _pbautologon: *mut BOOL) -> Result<()> {
|
||||||
|
// Auto-logon when this credential is selected
|
||||||
|
unsafe {
|
||||||
|
if !_pbautologon.is_null() {
|
||||||
|
*_pbautologon = TRUE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn SetDeselected(&self) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn GetFieldState(
|
||||||
|
&self,
|
||||||
|
_dwfieldid: u32,
|
||||||
|
_pcpfs: *mut CREDENTIAL_PROVIDER_FIELD_STATE,
|
||||||
|
_pcpfis: *mut CREDENTIAL_PROVIDER_FIELD_INTERACTIVE_STATE,
|
||||||
|
) -> Result<()> {
|
||||||
|
// We have a single display-only field showing "NetBird Login"
|
||||||
|
unsafe {
|
||||||
|
if !_pcpfs.is_null() {
|
||||||
|
*_pcpfs = CPFS_DISPLAY_IN_SELECTED_TILE;
|
||||||
|
}
|
||||||
|
if !_pcpfis.is_null() {
|
||||||
|
*_pcpfis = CPFIS_NONE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn GetStringValue(&self, _dwfieldid: u32) -> Result<PWSTR> {
|
||||||
|
let session = self.session.lock().unwrap();
|
||||||
|
let text = if let Some(ref s) = *session {
|
||||||
|
format!("NetBird: Logging in as {}", s.os_user)
|
||||||
|
} else {
|
||||||
|
"NetBird Login".to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
let wide: Vec<u16> = text.encode_utf16().chain(std::iter::once(0)).collect();
|
||||||
|
let ptr = unsafe {
|
||||||
|
let mem = windows::Win32::System::Com::CoTaskMemAlloc(wide.len() * 2) as *mut u16;
|
||||||
|
if mem.is_null() {
|
||||||
|
return Err(E_OUTOFMEMORY.into());
|
||||||
|
}
|
||||||
|
std::ptr::copy_nonoverlapping(wide.as_ptr(), mem, wide.len());
|
||||||
|
PWSTR(mem)
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(ptr)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn GetBitmapValue(&self, _dwfieldid: u32) -> Result<HBITMAP> {
|
||||||
|
Err(E_NOTIMPL.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn GetCheckboxValue(&self, _dwfieldid: u32, _pbchecked: *mut BOOL, _ppszlabel: *mut PWSTR) -> Result<()> {
|
||||||
|
Err(E_NOTIMPL.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn GetSubmitButtonValue(&self, _dwfieldid: u32, _pdwadjacentto: *mut u32) -> Result<()> {
|
||||||
|
Err(E_NOTIMPL.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn GetComboBoxValueCount(&self, _dwfieldid: u32, _pcitems: *mut u32, _pdwselecteditem: *mut u32) -> Result<()> {
|
||||||
|
Err(E_NOTIMPL.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn GetComboBoxValueAt(&self, _dwfieldid: u32, _dwitem: u32) -> Result<PWSTR> {
|
||||||
|
Err(E_NOTIMPL.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn SetStringValue(&self, _dwfieldid: u32, _psz: &PCWSTR) -> Result<()> {
|
||||||
|
Err(E_NOTIMPL.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn SetCheckboxValue(&self, _dwfieldid: u32, _bchecked: BOOL) -> Result<()> {
|
||||||
|
Err(E_NOTIMPL.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn SetComboBoxSelectedValue(&self, _dwfieldid: u32, _dwselecteditem: u32) -> Result<()> {
|
||||||
|
Err(E_NOTIMPL.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn CommandLinkClicked(&self, _dwfieldid: u32) -> Result<()> {
|
||||||
|
Err(E_NOTIMPL.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn GetSerialization(
|
||||||
|
&self,
|
||||||
|
_pcpgsr: *mut CREDENTIAL_PROVIDER_GET_SERIALIZATION_RESPONSE,
|
||||||
|
_pcpcs: *mut CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION,
|
||||||
|
_ppszoptionalstatustext: *mut PWSTR,
|
||||||
|
_pcpsioptionalstatusicon: *mut CREDENTIAL_PROVIDER_STATUS_ICON,
|
||||||
|
) -> Result<()> {
|
||||||
|
let session = self.session.lock().unwrap();
|
||||||
|
let session_info = match &*session {
|
||||||
|
Some(s) => s.clone(),
|
||||||
|
None => {
|
||||||
|
unsafe {
|
||||||
|
*_pcpgsr = CPGSR_NO_CREDENTIAL_NOT_FINISHED;
|
||||||
|
}
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Consume the session with the agent
|
||||||
|
if let Err(e) = NamedPipeClient::consume_session(&session_info.session_id) {
|
||||||
|
log::error!("Failed to consume RDP session: {}", e);
|
||||||
|
unsafe {
|
||||||
|
*_pcpgsr = CPGSR_NO_CREDENTIAL_FINISHED;
|
||||||
|
}
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform S4U logon
|
||||||
|
let username = &session_info.os_user;
|
||||||
|
let domain = if session_info.domain.is_empty() {
|
||||||
|
"."
|
||||||
|
} else {
|
||||||
|
&session_info.domain
|
||||||
|
};
|
||||||
|
|
||||||
|
match s4u::generate_s4u_token(username, domain) {
|
||||||
|
Ok(_token) => {
|
||||||
|
// In a full implementation, we would serialize the token into
|
||||||
|
// CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION format
|
||||||
|
// (KerbInteractiveLogon or MsV1_0InteractiveLogon structure).
|
||||||
|
//
|
||||||
|
// For the POC, we signal success. The actual serialization requires
|
||||||
|
// building the proper KERB_INTERACTIVE_LOGON or MSV1_0_INTERACTIVE_LOGON
|
||||||
|
// structure with the token handle, which is complex.
|
||||||
|
//
|
||||||
|
// TODO: Build proper credential serialization from S4U token
|
||||||
|
log::info!(
|
||||||
|
"S4U logon successful for {}\\{}, session {}",
|
||||||
|
domain,
|
||||||
|
username,
|
||||||
|
session_info.session_id
|
||||||
|
);
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
*_pcpgsr = CPGSR_RETURN_CREDENTIAL_FINISHED;
|
||||||
|
// Note: In production, pcpcs would be filled with the serialized credentials
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("S4U logon failed for {}\\{}: {}", domain, username, e);
|
||||||
|
unsafe {
|
||||||
|
*_pcpgsr = CPGSR_NO_CREDENTIAL_FINISHED;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ReportResult(
|
||||||
|
&self,
|
||||||
|
_ntstatus: NTSTATUS,
|
||||||
|
_ntssubstatus: NTSTATUS,
|
||||||
|
_ppszoptionalstatustext: *mut PWSTR,
|
||||||
|
_pcpsioptionalstatusicon: *mut CREDENTIAL_PROVIDER_STATUS_ICON,
|
||||||
|
) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
11
client/rdp/credprov/src/guid.rs
Normal file
11
client/rdp/credprov/src/guid.rs
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
use windows::core::GUID;
|
||||||
|
|
||||||
|
/// CLSID for the NetBird RDP Credential Provider.
|
||||||
|
/// Generated UUID: {7B3A8E5F-1C4D-4F8A-B2E6-9D0F3A7C5E1B}
|
||||||
|
pub const CLSID_NETBIRD_CREDENTIAL_PROVIDER: GUID = GUID::from_u128(
|
||||||
|
0x7B3A8E5F_1C4D_4F8A_B2E6_9D0F3A7C5E1B,
|
||||||
|
);
|
||||||
|
|
||||||
|
/// Registry path for credential providers.
|
||||||
|
pub const CREDENTIAL_PROVIDER_REGISTRY_PATH: &str =
|
||||||
|
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers";
|
||||||
309
client/rdp/credprov/src/lib.rs
Normal file
309
client/rdp/credprov/src/lib.rs
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
//! NetBird RDP Credential Provider for Windows.
|
||||||
|
//!
|
||||||
|
//! This DLL is a Windows Credential Provider that enables passwordless RDP access
|
||||||
|
//! to machines running the NetBird agent. It is loaded by Windows' LogonUI.exe
|
||||||
|
//! via COM when the login screen is displayed.
|
||||||
|
//!
|
||||||
|
//! ## How it works
|
||||||
|
//!
|
||||||
|
//! 1. The DLL is registered as a Credential Provider in the Windows registry
|
||||||
|
//! 2. When an RDP session begins, LogonUI loads the DLL
|
||||||
|
//! 3. The DLL queries the local NetBird agent via named pipe for pending sessions
|
||||||
|
//! 4. If a pending session exists for the connecting peer, the DLL:
|
||||||
|
//! - Shows a "NetBird Login" credential tile
|
||||||
|
//! - Performs S4U logon to create a Windows token without a password
|
||||||
|
//! - Returns the token to LogonUI for session creation
|
||||||
|
|
||||||
|
mod credential;
|
||||||
|
mod guid;
|
||||||
|
mod named_pipe_client;
|
||||||
|
mod provider;
|
||||||
|
mod s4u;
|
||||||
|
|
||||||
|
use guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
|
||||||
|
use provider::NetBirdCredentialProvider;
|
||||||
|
use std::sync::atomic::{AtomicU32, Ordering};
|
||||||
|
use windows::core::*;
|
||||||
|
use windows::Win32::Foundation::*;
|
||||||
|
use windows::Win32::System::Com::*;
|
||||||
|
|
||||||
|
/// DLL reference count for COM lifecycle management.
|
||||||
|
static DLL_REF_COUNT: AtomicU32 = AtomicU32::new(0);
|
||||||
|
|
||||||
|
/// DLL module handle.
|
||||||
|
static mut DLL_MODULE: HMODULE = HMODULE(std::ptr::null_mut());
|
||||||
|
|
||||||
|
/// COM class factory for creating NetBirdCredentialProvider instances.
|
||||||
|
#[implement(IClassFactory)]
|
||||||
|
struct NetBirdClassFactory;
|
||||||
|
|
||||||
|
impl IClassFactory_Impl for NetBirdClassFactory_Impl {
|
||||||
|
fn CreateInstance(
|
||||||
|
&self,
|
||||||
|
_punkouter: Option<&IUnknown>,
|
||||||
|
riid: *const GUID,
|
||||||
|
ppvobject: *mut *mut std::ffi::c_void,
|
||||||
|
) -> Result<()> {
|
||||||
|
unsafe {
|
||||||
|
if !ppvobject.is_null() {
|
||||||
|
*ppvobject = std::ptr::null_mut();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _punkouter.is_some() {
|
||||||
|
return Err(CLASS_E_NOAGGREGATION.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let provider = NetBirdCredentialProvider::new();
|
||||||
|
let unknown: IUnknown = provider.into();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
unknown.query(riid, ppvobject).ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn LockServer(&self, flock: BOOL) -> Result<()> {
|
||||||
|
if flock.as_bool() {
|
||||||
|
DLL_REF_COUNT.fetch_add(1, Ordering::SeqCst);
|
||||||
|
} else {
|
||||||
|
DLL_REF_COUNT.fetch_sub(1, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DLL entry point.
|
||||||
|
#[no_mangle]
|
||||||
|
extern "system" fn DllMain(hinstance: HMODULE, reason: u32, _reserved: *mut std::ffi::c_void) -> BOOL {
|
||||||
|
const DLL_PROCESS_ATTACH: u32 = 1;
|
||||||
|
|
||||||
|
if reason == DLL_PROCESS_ATTACH {
|
||||||
|
unsafe {
|
||||||
|
DLL_MODULE = hinstance;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TRUE
|
||||||
|
}
|
||||||
|
|
||||||
|
/// COM entry point: returns a class factory for the requested CLSID.
|
||||||
|
#[no_mangle]
|
||||||
|
extern "system" fn DllGetClassObject(
|
||||||
|
rclsid: *const GUID,
|
||||||
|
riid: *const GUID,
|
||||||
|
ppv: *mut *mut std::ffi::c_void,
|
||||||
|
) -> HRESULT {
|
||||||
|
unsafe {
|
||||||
|
if ppv.is_null() {
|
||||||
|
return E_POINTER;
|
||||||
|
}
|
||||||
|
*ppv = std::ptr::null_mut();
|
||||||
|
|
||||||
|
if *rclsid != CLSID_NETBIRD_CREDENTIAL_PROVIDER {
|
||||||
|
return CLASS_E_CLASSNOTAVAILABLE;
|
||||||
|
}
|
||||||
|
|
||||||
|
let factory = NetBirdClassFactory;
|
||||||
|
let unknown: IUnknown = factory.into();
|
||||||
|
|
||||||
|
match unknown.query(riid, ppv) {
|
||||||
|
Ok(()) => S_OK,
|
||||||
|
Err(e) => e.code(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// COM entry point: indicates whether the DLL can be unloaded.
|
||||||
|
#[no_mangle]
|
||||||
|
extern "system" fn DllCanUnloadNow() -> HRESULT {
|
||||||
|
if DLL_REF_COUNT.load(Ordering::SeqCst) == 0 {
|
||||||
|
S_OK
|
||||||
|
} else {
|
||||||
|
S_FALSE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Self-registration: called by regsvr32 to register the credential provider.
|
||||||
|
#[no_mangle]
|
||||||
|
extern "system" fn DllRegisterServer() -> HRESULT {
|
||||||
|
match register_credential_provider(true) {
|
||||||
|
Ok(()) => S_OK,
|
||||||
|
Err(_) => E_FAIL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Self-unregistration: called by regsvr32 /u to unregister the credential provider.
|
||||||
|
#[no_mangle]
|
||||||
|
extern "system" fn DllUnregisterServer() -> HRESULT {
|
||||||
|
match register_credential_provider(false) {
|
||||||
|
Ok(()) => S_OK,
|
||||||
|
Err(_) => E_FAIL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_credential_provider(register: bool) -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||||
|
use windows::Win32::System::Registry::*;
|
||||||
|
|
||||||
|
let clsid_str = format!("{{{:08X}-{:04X}-{:04X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}",
|
||||||
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data1,
|
||||||
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data2,
|
||||||
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data3,
|
||||||
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[0],
|
||||||
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[1],
|
||||||
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[2],
|
||||||
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[3],
|
||||||
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[4],
|
||||||
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[5],
|
||||||
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[6],
|
||||||
|
CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[7],
|
||||||
|
);
|
||||||
|
|
||||||
|
if register {
|
||||||
|
// Register under Credential Providers
|
||||||
|
let cp_key_path = format!(
|
||||||
|
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
|
||||||
|
clsid_str
|
||||||
|
);
|
||||||
|
|
||||||
|
let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||||
|
let mut hkey = HKEY::default();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let result = RegCreateKeyExW(
|
||||||
|
HKEY_LOCAL_MACHINE,
|
||||||
|
PCWSTR(cp_key_wide.as_ptr()),
|
||||||
|
0,
|
||||||
|
PCWSTR::null(),
|
||||||
|
REG_OPTION_NON_VOLATILE,
|
||||||
|
KEY_WRITE,
|
||||||
|
None,
|
||||||
|
&mut hkey,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
if result.is_err() {
|
||||||
|
return Err("Failed to create credential provider registry key".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let value: Vec<u16> = "NetBird RDP Credential Provider"
|
||||||
|
.encode_utf16()
|
||||||
|
.chain(std::iter::once(0))
|
||||||
|
.collect();
|
||||||
|
let _ = RegSetValueExW(
|
||||||
|
hkey,
|
||||||
|
PCWSTR::null(),
|
||||||
|
0,
|
||||||
|
REG_SZ,
|
||||||
|
Some(std::slice::from_raw_parts(
|
||||||
|
value.as_ptr() as *const u8,
|
||||||
|
value.len() * 2,
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
let _ = RegCloseKey(hkey);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register CLSID in CLSID hive
|
||||||
|
let clsid_key_path = format!(r"CLSID\{}", clsid_str);
|
||||||
|
let clsid_key_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let result = RegCreateKeyExW(
|
||||||
|
HKEY_CLASSES_ROOT,
|
||||||
|
PCWSTR(clsid_key_wide.as_ptr()),
|
||||||
|
0,
|
||||||
|
PCWSTR::null(),
|
||||||
|
REG_OPTION_NON_VOLATILE,
|
||||||
|
KEY_WRITE,
|
||||||
|
None,
|
||||||
|
&mut hkey,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
if result.is_err() {
|
||||||
|
return Err("Failed to create CLSID registry key".into());
|
||||||
|
}
|
||||||
|
let _ = RegCloseKey(hkey);
|
||||||
|
|
||||||
|
// InprocServer32 subkey
|
||||||
|
let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
|
||||||
|
let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||||
|
|
||||||
|
let result = RegCreateKeyExW(
|
||||||
|
HKEY_CLASSES_ROOT,
|
||||||
|
PCWSTR(inproc_wide.as_ptr()),
|
||||||
|
0,
|
||||||
|
PCWSTR::null(),
|
||||||
|
REG_OPTION_NON_VOLATILE,
|
||||||
|
KEY_WRITE,
|
||||||
|
None,
|
||||||
|
&mut hkey,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
if result.is_err() {
|
||||||
|
return Err("Failed to create InprocServer32 registry key".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set DLL path
|
||||||
|
let mut dll_path = [0u16; 260];
|
||||||
|
let len = windows::Win32::System::LibraryLoader::GetModuleFileNameW(
|
||||||
|
DLL_MODULE,
|
||||||
|
&mut dll_path,
|
||||||
|
);
|
||||||
|
if len > 0 {
|
||||||
|
let _ = RegSetValueExW(
|
||||||
|
hkey,
|
||||||
|
PCWSTR::null(),
|
||||||
|
0,
|
||||||
|
REG_SZ,
|
||||||
|
Some(std::slice::from_raw_parts(
|
||||||
|
dll_path.as_ptr() as *const u8,
|
||||||
|
(len as usize + 1) * 2,
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set threading model
|
||||||
|
let threading: Vec<u16> = "Apartment"
|
||||||
|
.encode_utf16()
|
||||||
|
.chain(std::iter::once(0))
|
||||||
|
.collect();
|
||||||
|
let threading_name: Vec<u16> = "ThreadingModel"
|
||||||
|
.encode_utf16()
|
||||||
|
.chain(std::iter::once(0))
|
||||||
|
.collect();
|
||||||
|
let _ = RegSetValueExW(
|
||||||
|
hkey,
|
||||||
|
PCWSTR(threading_name.as_ptr()),
|
||||||
|
0,
|
||||||
|
REG_SZ,
|
||||||
|
Some(std::slice::from_raw_parts(
|
||||||
|
threading.as_ptr() as *const u8,
|
||||||
|
threading.len() * 2,
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = RegCloseKey(hkey);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unregister
|
||||||
|
let cp_key_path = format!(
|
||||||
|
r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
|
||||||
|
clsid_str
|
||||||
|
);
|
||||||
|
let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let _ = RegDeleteKeyW(HKEY_LOCAL_MACHINE, PCWSTR(cp_key_wide.as_ptr()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
|
||||||
|
let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||||
|
let clsid_key_path = format!(r"CLSID\{}", clsid_str);
|
||||||
|
let clsid_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(inproc_wide.as_ptr()));
|
||||||
|
let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(clsid_wide.as_ptr()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
135
client/rdp/credprov/src/named_pipe_client.rs
Normal file
135
client/rdp/credprov/src/named_pipe_client.rs
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::io::{Read, Write};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
/// Named pipe path for communicating with the NetBird agent.
|
||||||
|
const PIPE_NAME: &str = r"\\.\pipe\netbird-rdp-auth";
|
||||||
|
|
||||||
|
/// Maximum response size from the agent.
|
||||||
|
const MAX_RESPONSE_SIZE: usize = 4096;
|
||||||
|
|
||||||
|
/// Timeout for named pipe operations.
|
||||||
|
const PIPE_TIMEOUT: Duration = Duration::from_secs(5);
|
||||||
|
|
||||||
|
/// Request sent to the NetBird agent via named pipe.
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct PipeRequest {
|
||||||
|
pub action: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub remote_ip: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub session_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response received from the NetBird agent via named pipe.
|
||||||
|
#[derive(Deserialize, Debug, Clone)]
|
||||||
|
pub struct PipeResponse {
|
||||||
|
pub found: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
pub session_id: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub os_user: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub domain: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Client for communicating with the NetBird agent's named pipe server.
|
||||||
|
pub struct NamedPipeClient;
|
||||||
|
|
||||||
|
impl NamedPipeClient {
|
||||||
|
/// Query the NetBird agent for a pending RDP session matching the given remote IP.
|
||||||
|
pub fn query_pending(remote_ip: &str) -> Result<PipeResponse, PipeError> {
|
||||||
|
let request = PipeRequest {
|
||||||
|
action: "query_pending".to_string(),
|
||||||
|
remote_ip: Some(remote_ip.to_string()),
|
||||||
|
session_id: None,
|
||||||
|
};
|
||||||
|
Self::send_request(&request)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tell the NetBird agent to consume (mark as used) a pending session.
|
||||||
|
pub fn consume_session(session_id: &str) -> Result<PipeResponse, PipeError> {
|
||||||
|
let request = PipeRequest {
|
||||||
|
action: "consume".to_string(),
|
||||||
|
remote_ip: None,
|
||||||
|
session_id: Some(session_id.to_string()),
|
||||||
|
};
|
||||||
|
Self::send_request(&request)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn send_request(request: &PipeRequest) -> Result<PipeResponse, PipeError> {
|
||||||
|
let request_data =
|
||||||
|
serde_json::to_vec(request).map_err(|e| PipeError::Serialization(e.to_string()))?;
|
||||||
|
|
||||||
|
// Open named pipe (CreateFile in Windows)
|
||||||
|
let mut pipe = Self::open_pipe()?;
|
||||||
|
|
||||||
|
// Write request
|
||||||
|
pipe.write_all(&request_data)
|
||||||
|
.map_err(|e| PipeError::Write(e.to_string()))?;
|
||||||
|
|
||||||
|
// Shutdown write side to signal end of request
|
||||||
|
// For named pipes on Windows, we rely on the message boundary
|
||||||
|
pipe.flush()
|
||||||
|
.map_err(|e| PipeError::Write(e.to_string()))?;
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
let mut response_data = vec![0u8; MAX_RESPONSE_SIZE];
|
||||||
|
let n = pipe
|
||||||
|
.read(&mut response_data)
|
||||||
|
.map_err(|e| PipeError::Read(e.to_string()))?;
|
||||||
|
|
||||||
|
let response: PipeResponse = serde_json::from_slice(&response_data[..n])
|
||||||
|
.map_err(|e| PipeError::Deserialization(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn open_pipe() -> Result<std::fs::File, PipeError> {
|
||||||
|
// On Windows, named pipes are opened like files
|
||||||
|
use std::fs::OpenOptions;
|
||||||
|
|
||||||
|
// Try to open the pipe with a brief retry for PIPE_BUSY
|
||||||
|
for attempt in 0..3 {
|
||||||
|
match OpenOptions::new().read(true).write(true).open(PIPE_NAME) {
|
||||||
|
Ok(file) => return Ok(file),
|
||||||
|
Err(e) => {
|
||||||
|
if attempt < 2 {
|
||||||
|
std::thread::sleep(Duration::from_millis(100));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return Err(PipeError::Connect(format!(
|
||||||
|
"failed to open pipe {}: {}",
|
||||||
|
PIPE_NAME, e
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(PipeError::Connect("exhausted pipe connection attempts".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Errors that can occur during named pipe communication.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum PipeError {
|
||||||
|
Connect(String),
|
||||||
|
Write(String),
|
||||||
|
Read(String),
|
||||||
|
Serialization(String),
|
||||||
|
Deserialization(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for PipeError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
PipeError::Connect(e) => write!(f, "pipe connect: {}", e),
|
||||||
|
PipeError::Write(e) => write!(f, "pipe write: {}", e),
|
||||||
|
PipeError::Read(e) => write!(f, "pipe read: {}", e),
|
||||||
|
PipeError::Serialization(e) => write!(f, "pipe serialization: {}", e),
|
||||||
|
PipeError::Deserialization(e) => write!(f, "pipe deserialization: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for PipeError {}
|
||||||
270
client/rdp/credprov/src/provider.rs
Normal file
270
client/rdp/credprov/src/provider.rs
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
//! ICredentialProvider implementation.
|
||||||
|
//!
|
||||||
|
//! This is the main COM object that Windows' LogonUI.exe instantiates.
|
||||||
|
//! It determines whether to show a "NetBird Login" credential tile based on
|
||||||
|
//! whether the NetBird agent has a pending RDP session for the connecting peer.
|
||||||
|
|
||||||
|
use crate::credential::NetBirdCredential;
|
||||||
|
use crate::guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
|
||||||
|
use crate::named_pipe_client::NamedPipeClient;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use windows::core::*;
|
||||||
|
use windows::Win32::Foundation::*;
|
||||||
|
use windows::Win32::Security::Credentials::*;
|
||||||
|
use windows::Win32::System::RemoteDesktop::*;
|
||||||
|
|
||||||
|
/// The NetBird Credential Provider, loaded by LogonUI.exe via COM.
|
||||||
|
#[implement(ICredentialProvider)]
|
||||||
|
pub struct NetBirdCredentialProvider {
|
||||||
|
/// The credential tile (if a pending session was found).
|
||||||
|
credential: Mutex<Option<ICredentialProviderCredential>>,
|
||||||
|
/// Whether this provider is active for the current usage scenario.
|
||||||
|
active: Mutex<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NetBirdCredentialProvider {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
credential: Mutex::new(None),
|
||||||
|
active: Mutex::new(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ICredentialProvider_Impl for NetBirdCredentialProvider_Impl {
|
||||||
|
fn SetUsageScenario(
|
||||||
|
&self,
|
||||||
|
cpus: CREDENTIAL_PROVIDER_USAGE_SCENARIO,
|
||||||
|
_dwflags: u32,
|
||||||
|
) -> Result<()> {
|
||||||
|
let mut active = self.active.lock().unwrap();
|
||||||
|
|
||||||
|
match cpus {
|
||||||
|
CPUS_LOGON | CPUS_UNLOCK_WORKSTATION => {
|
||||||
|
// We activate for RDP logon and unlock scenarios
|
||||||
|
*active = true;
|
||||||
|
log::info!("NetBird CP activated for usage scenario {:?}", cpus.0);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Don't activate for credui or other scenarios
|
||||||
|
*active = false;
|
||||||
|
Err(E_NOTIMPL.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn SetSerialization(
|
||||||
|
&self,
|
||||||
|
_pcpcs: *const CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION,
|
||||||
|
) -> Result<()> {
|
||||||
|
Err(E_NOTIMPL.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn Advise(
|
||||||
|
&self,
|
||||||
|
_pcpe: Option<&ICredentialProviderEvents>,
|
||||||
|
_upadvisecontext: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn UnAdvise(&self) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn GetFieldDescriptorCount(&self) -> Result<u32> {
|
||||||
|
// We have one field: a large text label showing "NetBird: Logging in as <user>"
|
||||||
|
Ok(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn GetFieldDescriptorAt(
|
||||||
|
&self,
|
||||||
|
_dwindex: u32,
|
||||||
|
_ppcpfd: *mut *mut CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR,
|
||||||
|
) -> Result<()> {
|
||||||
|
if _dwindex != 0 {
|
||||||
|
return Err(E_INVALIDARG.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let label = "NetBird Login";
|
||||||
|
let wide: Vec<u16> = label.encode_utf16().chain(std::iter::once(0)).collect();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let desc = windows::Win32::System::Com::CoTaskMemAlloc(
|
||||||
|
std::mem::size_of::<CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR>(),
|
||||||
|
) as *mut CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR;
|
||||||
|
|
||||||
|
if desc.is_null() {
|
||||||
|
return Err(E_OUTOFMEMORY.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let label_mem =
|
||||||
|
windows::Win32::System::Com::CoTaskMemAlloc(wide.len() * 2) as *mut u16;
|
||||||
|
if label_mem.is_null() {
|
||||||
|
windows::Win32::System::Com::CoTaskMemFree(Some(desc as *const _));
|
||||||
|
return Err(E_OUTOFMEMORY.into());
|
||||||
|
}
|
||||||
|
std::ptr::copy_nonoverlapping(wide.as_ptr(), label_mem, wide.len());
|
||||||
|
|
||||||
|
(*desc).dwFieldID = 0;
|
||||||
|
(*desc).cpft = CPFT_LARGE_TEXT;
|
||||||
|
(*desc).pszLabel = PWSTR(label_mem);
|
||||||
|
(*desc).guidFieldType = GUID::zeroed();
|
||||||
|
|
||||||
|
*_ppcpfd = desc;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn GetCredentialCount(
|
||||||
|
&self,
|
||||||
|
_pdwcount: *mut u32,
|
||||||
|
_pdwdefault: *mut u32,
|
||||||
|
_pbautologinwithdefault: *mut BOOL,
|
||||||
|
) -> Result<()> {
|
||||||
|
let active = self.active.lock().unwrap();
|
||||||
|
if !*active {
|
||||||
|
unsafe {
|
||||||
|
*_pdwcount = 0;
|
||||||
|
*_pdwdefault = u32::MAX;
|
||||||
|
*_pbautologinwithdefault = FALSE;
|
||||||
|
}
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get the client IP of the current RDP session
|
||||||
|
let remote_ip = match get_rdp_client_ip() {
|
||||||
|
Some(ip) => ip,
|
||||||
|
None => {
|
||||||
|
log::debug!("NetBird CP: could not determine RDP client IP");
|
||||||
|
unsafe {
|
||||||
|
*_pdwcount = 0;
|
||||||
|
*_pdwdefault = u32::MAX;
|
||||||
|
*_pbautologinwithdefault = FALSE;
|
||||||
|
}
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Query the NetBird agent for a pending session
|
||||||
|
match NamedPipeClient::query_pending(&remote_ip) {
|
||||||
|
Ok(response) if response.found => {
|
||||||
|
log::info!(
|
||||||
|
"NetBird CP: found pending session for {} -> {}",
|
||||||
|
remote_ip,
|
||||||
|
response.os_user
|
||||||
|
);
|
||||||
|
|
||||||
|
let cred = NetBirdCredential::new(remote_ip, response);
|
||||||
|
let icred: ICredentialProviderCredential = cred.into();
|
||||||
|
|
||||||
|
let mut credential = self.credential.lock().unwrap();
|
||||||
|
*credential = Some(icred);
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
*_pdwcount = 1;
|
||||||
|
*_pdwdefault = 0;
|
||||||
|
*_pbautologinwithdefault = TRUE; // auto-logon
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(_) => {
|
||||||
|
log::debug!("NetBird CP: no pending session for {}", remote_ip);
|
||||||
|
unsafe {
|
||||||
|
*_pdwcount = 0;
|
||||||
|
*_pdwdefault = u32::MAX;
|
||||||
|
*_pbautologinwithdefault = FALSE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::debug!("NetBird CP: pipe query failed: {}", e);
|
||||||
|
unsafe {
|
||||||
|
*_pdwcount = 0;
|
||||||
|
*_pdwdefault = u32::MAX;
|
||||||
|
*_pbautologinwithdefault = FALSE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn GetCredentialAt(
|
||||||
|
&self,
|
||||||
|
_dwindex: u32,
|
||||||
|
_ppcpc: *mut Option<ICredentialProviderCredential>,
|
||||||
|
) -> Result<()> {
|
||||||
|
if _dwindex != 0 {
|
||||||
|
return Err(E_INVALIDARG.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let credential = self.credential.lock().unwrap();
|
||||||
|
match &*credential {
|
||||||
|
Some(cred) => {
|
||||||
|
unsafe {
|
||||||
|
*_ppcpc = Some(cred.clone());
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
None => Err(E_UNEXPECTED.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the IP address of the remote RDP client for the current session.
|
||||||
|
fn get_rdp_client_ip() -> Option<String> {
|
||||||
|
unsafe {
|
||||||
|
// Get the current session ID
|
||||||
|
let process_id = windows::Win32::System::Threading::GetCurrentProcessId();
|
||||||
|
let mut session_id = 0u32;
|
||||||
|
|
||||||
|
if !windows::Win32::System::RemoteDesktop::ProcessIdToSessionId(process_id, &mut session_id)
|
||||||
|
.as_bool()
|
||||||
|
{
|
||||||
|
log::debug!("ProcessIdToSessionId failed");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query the client address
|
||||||
|
let mut buffer: *mut WTS_CLIENT_ADDRESS = std::ptr::null_mut();
|
||||||
|
let mut bytes_returned = 0u32;
|
||||||
|
|
||||||
|
let result = WTSQuerySessionInformationW(
|
||||||
|
WTS_CURRENT_SERVER_HANDLE,
|
||||||
|
session_id,
|
||||||
|
WTS_INFO_CLASS(14), // WTSClientAddress
|
||||||
|
&mut buffer as *mut _ as *mut *mut u16,
|
||||||
|
&mut bytes_returned,
|
||||||
|
);
|
||||||
|
|
||||||
|
if !result.as_bool() || buffer.is_null() {
|
||||||
|
log::debug!("WTSQuerySessionInformation(WTSClientAddress) failed");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let client_addr = &*buffer;
|
||||||
|
let ip = match client_addr.AddressFamily as u32 {
|
||||||
|
// AF_INET
|
||||||
|
2 => {
|
||||||
|
let addr = &client_addr.Address;
|
||||||
|
Some(format!("{}.{}.{}.{}", addr[2], addr[3], addr[4], addr[5]))
|
||||||
|
}
|
||||||
|
// AF_INET6
|
||||||
|
23 => {
|
||||||
|
// IPv6 - extract from Address bytes
|
||||||
|
let addr = &client_addr.Address;
|
||||||
|
Some(format!(
|
||||||
|
"{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}",
|
||||||
|
addr[2], addr[3], addr[4], addr[5], addr[6], addr[7], addr[8], addr[9],
|
||||||
|
addr[10], addr[11], addr[12], addr[13], addr[14], addr[15], addr[16], addr[17]
|
||||||
|
))
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
WTSFreeMemory(buffer as *mut std::ffi::c_void);
|
||||||
|
|
||||||
|
ip
|
||||||
|
}
|
||||||
|
}
|
||||||
398
client/rdp/credprov/src/s4u.rs
Normal file
398
client/rdp/credprov/src/s4u.rs
Normal file
@@ -0,0 +1,398 @@
|
|||||||
|
//! S4U (Service for User) authentication for Windows.
|
||||||
|
//!
|
||||||
|
//! This module ports the S4U logon logic from the Go implementation at:
|
||||||
|
//! `client/ssh/server/executor_windows.go:generateS4UUserToken()`
|
||||||
|
//!
|
||||||
|
//! It creates Windows logon tokens without requiring a password, using the LSA
|
||||||
|
//! (Local Security Authority) S4U mechanism. This is the same approach used by
|
||||||
|
//! OpenSSH for Windows for public key authentication.
|
||||||
|
|
||||||
|
use std::ptr;
|
||||||
|
use windows::core::{PCSTR, PWSTR};
|
||||||
|
use windows::Win32::Foundation::{HANDLE, LUID, NTSTATUS, PSID};
|
||||||
|
use windows::Win32::Security::Authentication::Identity::{
|
||||||
|
LsaDeregisterLogonProcess, LsaFreeReturnBuffer, LsaLogonUser, LsaLookupAuthenticationPackage,
|
||||||
|
LsaRegisterLogonProcess, KERB_S4U_LOGON, MSV1_0_S4U_LOGON, MSV1_0_S4U_LOGON_FLAG_CHECK_LOGONHOURS,
|
||||||
|
SECURITY_LOGON_TYPE,
|
||||||
|
};
|
||||||
|
use windows::Win32::Security::{
|
||||||
|
QUOTA_LIMITS, TOKEN_SOURCE,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Status code for successful LSA operations.
|
||||||
|
const STATUS_SUCCESS: i32 = 0;
|
||||||
|
|
||||||
|
/// Network logon type (used for S4U).
|
||||||
|
const LOGON32_LOGON_NETWORK: SECURITY_LOGON_TYPE = SECURITY_LOGON_TYPE(3);
|
||||||
|
|
||||||
|
/// Kerberos S4U logon message type.
|
||||||
|
const KERB_S4U_LOGON_TYPE: u32 = 12;
|
||||||
|
|
||||||
|
/// MSV1_0 S4U logon message type.
|
||||||
|
const MSV1_0_S4U_LOGON_TYPE: u32 = 12;
|
||||||
|
|
||||||
|
/// Authentication package name for Kerberos.
|
||||||
|
const KERBEROS_PACKAGE: &str = "Kerberos";
|
||||||
|
|
||||||
|
/// Authentication package name for MSV1_0 (local users).
|
||||||
|
const MSV1_0_PACKAGE: &str = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0";
|
||||||
|
|
||||||
|
/// Result of a successful S4U logon.
|
||||||
|
pub struct S4UToken {
|
||||||
|
pub handle: HANDLE,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for S4UToken {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if !self.handle.is_invalid() {
|
||||||
|
unsafe {
|
||||||
|
let _ = windows::Win32::Foundation::CloseHandle(self.handle);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Errors from S4U logon operations.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum S4UError {
|
||||||
|
LsaRegister(NTSTATUS),
|
||||||
|
LookupPackage(NTSTATUS),
|
||||||
|
LogonUser(NTSTATUS, i32),
|
||||||
|
AllocateLuid,
|
||||||
|
InvalidUsername(String),
|
||||||
|
Utf16Conversion(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for S4UError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
S4UError::LsaRegister(s) => write!(f, "LsaRegisterLogonProcess: 0x{:x}", s.0),
|
||||||
|
S4UError::LookupPackage(s) => write!(f, "LsaLookupAuthenticationPackage: 0x{:x}", s.0),
|
||||||
|
S4UError::LogonUser(s, sub) => {
|
||||||
|
write!(f, "LsaLogonUser S4U: NTSTATUS=0x{:x}, SubStatus=0x{:x}", s.0, sub)
|
||||||
|
}
|
||||||
|
S4UError::AllocateLuid => write!(f, "AllocateLocallyUniqueId failed"),
|
||||||
|
S4UError::InvalidUsername(u) => write!(f, "invalid username: {}", u),
|
||||||
|
S4UError::Utf16Conversion(s) => write!(f, "UTF-16 conversion: {}", s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for S4UError {}
|
||||||
|
|
||||||
|
/// Generate a Windows logon token using S4U authentication.
|
||||||
|
///
|
||||||
|
/// This creates a token for the specified user without requiring a password.
|
||||||
|
/// The calling process must have SeTcbPrivilege (typically SYSTEM).
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `username` - The Windows username (without domain prefix)
|
||||||
|
/// * `domain` - The domain name ("." for local users)
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// An `S4UToken` containing the Windows logon token handle.
|
||||||
|
pub fn generate_s4u_token(username: &str, domain: &str) -> Result<S4UToken, S4UError> {
|
||||||
|
if username.is_empty() {
|
||||||
|
return Err(S4UError::InvalidUsername("empty username".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let is_local = is_local_user(domain);
|
||||||
|
|
||||||
|
// Initialize LSA connection
|
||||||
|
let lsa_handle = initialize_lsa_connection()?;
|
||||||
|
|
||||||
|
// Lookup authentication package
|
||||||
|
let auth_package_id = lookup_auth_package(lsa_handle, is_local)?;
|
||||||
|
|
||||||
|
// Perform S4U logon
|
||||||
|
let result = perform_s4u_logon(lsa_handle, auth_package_id, username, domain, is_local);
|
||||||
|
|
||||||
|
// Cleanup LSA connection
|
||||||
|
unsafe {
|
||||||
|
let _ = LsaDeregisterLogonProcess(lsa_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_local_user(domain: &str) -> bool {
|
||||||
|
domain.is_empty() || domain == "."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn initialize_lsa_connection() -> Result<HANDLE, S4UError> {
|
||||||
|
let process_name = "NetBird\0";
|
||||||
|
let mut lsa_string = windows::Win32::Security::Authentication::Identity::LSA_STRING {
|
||||||
|
Length: (process_name.len() - 1) as u16,
|
||||||
|
MaximumLength: process_name.len() as u16,
|
||||||
|
Buffer: windows::core::PSTR(process_name.as_ptr() as *mut u8),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut lsa_handle = HANDLE::default();
|
||||||
|
let mut mode = 0u32;
|
||||||
|
|
||||||
|
let status = unsafe {
|
||||||
|
LsaRegisterLogonProcess(&mut lsa_string, &mut lsa_handle, &mut mode)
|
||||||
|
};
|
||||||
|
|
||||||
|
if status.0 != STATUS_SUCCESS {
|
||||||
|
return Err(S4UError::LsaRegister(status));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(lsa_handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn lookup_auth_package(lsa_handle: HANDLE, is_local: bool) -> Result<u32, S4UError> {
|
||||||
|
let package_name = if is_local { MSV1_0_PACKAGE } else { KERBEROS_PACKAGE };
|
||||||
|
let package_with_null = format!("{}\0", package_name);
|
||||||
|
|
||||||
|
let mut lsa_string = windows::Win32::Security::Authentication::Identity::LSA_STRING {
|
||||||
|
Length: (package_with_null.len() - 1) as u16,
|
||||||
|
MaximumLength: package_with_null.len() as u16,
|
||||||
|
Buffer: windows::core::PSTR(package_with_null.as_ptr() as *mut u8),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut auth_package_id = 0u32;
|
||||||
|
let status = unsafe {
|
||||||
|
LsaLookupAuthenticationPackage(lsa_handle, &mut lsa_string, &mut auth_package_id)
|
||||||
|
};
|
||||||
|
|
||||||
|
if status.0 != STATUS_SUCCESS {
|
||||||
|
return Err(S4UError::LookupPackage(status));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(auth_package_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn perform_s4u_logon(
|
||||||
|
lsa_handle: HANDLE,
|
||||||
|
auth_package_id: u32,
|
||||||
|
username: &str,
|
||||||
|
domain: &str,
|
||||||
|
is_local: bool,
|
||||||
|
) -> Result<S4UToken, S4UError> {
|
||||||
|
// Prepare token source
|
||||||
|
let mut source_name = [0u8; 8];
|
||||||
|
let name_bytes = b"netbird";
|
||||||
|
source_name[..name_bytes.len()].copy_from_slice(name_bytes);
|
||||||
|
|
||||||
|
let mut source_id = LUID::default();
|
||||||
|
let alloc_ok = unsafe {
|
||||||
|
windows::Win32::System::SystemInformation::GetSystemTimeAsFileTime(
|
||||||
|
&mut std::mem::zeroed(),
|
||||||
|
);
|
||||||
|
// Use a simpler approach - just use the current time as a unique ID
|
||||||
|
source_id.LowPart = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.subsec_nanos();
|
||||||
|
source_id.HighPart = std::process::id() as i32;
|
||||||
|
true
|
||||||
|
};
|
||||||
|
|
||||||
|
if !alloc_ok {
|
||||||
|
return Err(S4UError::AllocateLuid);
|
||||||
|
}
|
||||||
|
|
||||||
|
let token_source = TOKEN_SOURCE {
|
||||||
|
SourceName: source_name,
|
||||||
|
SourceIdentifier: source_id,
|
||||||
|
};
|
||||||
|
|
||||||
|
let origin_name_str = "netbird\0";
|
||||||
|
let mut origin_name = windows::Win32::Security::Authentication::Identity::LSA_STRING {
|
||||||
|
Length: (origin_name_str.len() - 1) as u16,
|
||||||
|
MaximumLength: origin_name_str.len() as u16,
|
||||||
|
Buffer: windows::core::PSTR(origin_name_str.as_ptr() as *mut u8),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Build the logon info structure
|
||||||
|
let (logon_info_ptr, logon_info_size) = if is_local {
|
||||||
|
build_msv1_0_s4u_logon(username)?
|
||||||
|
} else {
|
||||||
|
build_kerb_s4u_logon(username, domain)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut profile: *mut std::ffi::c_void = ptr::null_mut();
|
||||||
|
let mut profile_size = 0u32;
|
||||||
|
let mut logon_id = LUID::default();
|
||||||
|
let mut token = HANDLE::default();
|
||||||
|
let mut quotas = QUOTA_LIMITS::default();
|
||||||
|
let mut sub_status: i32 = 0;
|
||||||
|
|
||||||
|
let status = unsafe {
|
||||||
|
LsaLogonUser(
|
||||||
|
lsa_handle,
|
||||||
|
&mut origin_name,
|
||||||
|
LOGON32_LOGON_NETWORK,
|
||||||
|
auth_package_id,
|
||||||
|
logon_info_ptr as *const std::ffi::c_void,
|
||||||
|
logon_info_size as u32,
|
||||||
|
None, // local groups
|
||||||
|
&token_source,
|
||||||
|
&mut profile,
|
||||||
|
&mut profile_size,
|
||||||
|
&mut logon_id,
|
||||||
|
&mut token,
|
||||||
|
&mut quotas,
|
||||||
|
&mut sub_status,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Free profile buffer if allocated
|
||||||
|
if !profile.is_null() {
|
||||||
|
unsafe {
|
||||||
|
let _ = LsaFreeReturnBuffer(profile);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free the logon info buffer
|
||||||
|
unsafe {
|
||||||
|
let layout = std::alloc::Layout::from_size_align_unchecked(logon_info_size, 8);
|
||||||
|
std::alloc::dealloc(logon_info_ptr as *mut u8, layout);
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.0 != STATUS_SUCCESS {
|
||||||
|
return Err(S4UError::LogonUser(status, sub_status));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(S4UToken { handle: token })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build MSV1_0_S4U_LOGON structure for local users.
|
||||||
|
fn build_msv1_0_s4u_logon(username: &str) -> Result<(*mut u8, usize), S4UError> {
|
||||||
|
let username_utf16: Vec<u16> = username.encode_utf16().chain(std::iter::once(0)).collect();
|
||||||
|
let domain_utf16: Vec<u16> = ".".encode_utf16().chain(std::iter::once(0)).collect();
|
||||||
|
|
||||||
|
let username_byte_size = username_utf16.len() * 2;
|
||||||
|
let domain_byte_size = domain_utf16.len() * 2;
|
||||||
|
|
||||||
|
// MSV1_0_S4U_LOGON structure:
|
||||||
|
// MessageType: u32 (4 bytes)
|
||||||
|
// Flags: u32 (4 bytes)
|
||||||
|
// UserPrincipalName: UNICODE_STRING (8 bytes on 32-bit, 16 bytes on 64-bit)
|
||||||
|
// DomainName: UNICODE_STRING
|
||||||
|
let struct_size = std::mem::size_of::<MSV1_0_S4U_LOGON_HEADER>();
|
||||||
|
let total_size = struct_size + username_byte_size + domain_byte_size;
|
||||||
|
|
||||||
|
let layout = std::alloc::Layout::from_size_align(total_size, 8).unwrap();
|
||||||
|
let buffer = unsafe { std::alloc::alloc_zeroed(layout) };
|
||||||
|
|
||||||
|
if buffer.is_null() {
|
||||||
|
return Err(S4UError::Utf16Conversion("allocation failed".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// For the POC, we'll set up the raw bytes manually since the windows-rs
|
||||||
|
// MSV1_0_S4U_LOGON structure layout may differ.
|
||||||
|
// This is a simplified version - in production, use proper FFI bindings.
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
// MessageType = MSV1_0_S4U_LOGON_TYPE (12)
|
||||||
|
*(buffer as *mut u32) = MSV1_0_S4U_LOGON_TYPE;
|
||||||
|
// Flags = 0
|
||||||
|
*((buffer as *mut u32).add(1)) = 0;
|
||||||
|
|
||||||
|
// Copy username UTF-16 after the structure
|
||||||
|
let username_offset = struct_size;
|
||||||
|
let username_dest = buffer.add(username_offset);
|
||||||
|
ptr::copy_nonoverlapping(
|
||||||
|
username_utf16.as_ptr() as *const u8,
|
||||||
|
username_dest,
|
||||||
|
username_byte_size,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Copy domain UTF-16 after username
|
||||||
|
let domain_offset = username_offset + username_byte_size;
|
||||||
|
let domain_dest = buffer.add(domain_offset);
|
||||||
|
ptr::copy_nonoverlapping(
|
||||||
|
domain_utf16.as_ptr() as *const u8,
|
||||||
|
domain_dest,
|
||||||
|
domain_byte_size,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Set UNICODE_STRING for UserPrincipalName (offset 8 on 64-bit)
|
||||||
|
// Length, MaximumLength, Buffer pointer
|
||||||
|
let upn_ptr = buffer.add(8) as *mut u16;
|
||||||
|
*upn_ptr = ((username_utf16.len() - 1) * 2) as u16; // Length (without null)
|
||||||
|
*(upn_ptr.add(1)) = (username_utf16.len() * 2) as u16; // MaximumLength
|
||||||
|
*((buffer.add(8 + 4)) as *mut *const u8) = username_dest; // Buffer
|
||||||
|
|
||||||
|
// Set UNICODE_STRING for DomainName
|
||||||
|
let dn_offset = 8 + std::mem::size_of::<UnicodeStringRaw>();
|
||||||
|
let dn_ptr = buffer.add(dn_offset) as *mut u16;
|
||||||
|
*dn_ptr = ((domain_utf16.len() - 1) * 2) as u16;
|
||||||
|
*(dn_ptr.add(1)) = (domain_utf16.len() * 2) as u16;
|
||||||
|
*((buffer.add(dn_offset + 4)) as *mut *const u8) = domain_dest;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((buffer, total_size))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build KERB_S4U_LOGON structure for domain users.
|
||||||
|
fn build_kerb_s4u_logon(username: &str, domain: &str) -> Result<(*mut u8, usize), S4UError> {
|
||||||
|
// Build UPN: username@domain
|
||||||
|
let upn = format!("{}@{}", username, domain);
|
||||||
|
let upn_utf16: Vec<u16> = upn.encode_utf16().chain(std::iter::once(0)).collect();
|
||||||
|
let upn_byte_size = upn_utf16.len() * 2;
|
||||||
|
|
||||||
|
let struct_size = std::mem::size_of::<KerbS4ULogonHeader>();
|
||||||
|
let total_size = struct_size + upn_byte_size;
|
||||||
|
|
||||||
|
let layout = std::alloc::Layout::from_size_align(total_size, 8).unwrap();
|
||||||
|
let buffer = unsafe { std::alloc::alloc_zeroed(layout) };
|
||||||
|
|
||||||
|
if buffer.is_null() {
|
||||||
|
return Err(S4UError::Utf16Conversion("allocation failed".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
// MessageType = KERB_S4U_LOGON_TYPE (12)
|
||||||
|
*(buffer as *mut u32) = KERB_S4U_LOGON_TYPE;
|
||||||
|
// Flags = 0
|
||||||
|
*((buffer as *mut u32).add(1)) = 0;
|
||||||
|
|
||||||
|
// Copy UPN UTF-16 after the structure
|
||||||
|
let upn_offset = struct_size;
|
||||||
|
let upn_dest = buffer.add(upn_offset);
|
||||||
|
ptr::copy_nonoverlapping(
|
||||||
|
upn_utf16.as_ptr() as *const u8,
|
||||||
|
upn_dest,
|
||||||
|
upn_byte_size,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Set UNICODE_STRING for ClientUpn (offset 8)
|
||||||
|
let upn_str_ptr = buffer.add(8) as *mut u16;
|
||||||
|
*upn_str_ptr = ((upn_utf16.len() - 1) * 2) as u16;
|
||||||
|
*(upn_str_ptr.add(1)) = (upn_utf16.len() * 2) as u16;
|
||||||
|
*((buffer.add(8 + 4)) as *mut *const u8) = upn_dest;
|
||||||
|
|
||||||
|
// ClientRealm is empty (zeroed)
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((buffer, total_size))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw UNICODE_STRING layout for size calculation.
|
||||||
|
#[repr(C)]
|
||||||
|
struct UnicodeStringRaw {
|
||||||
|
_length: u16,
|
||||||
|
_maximum_length: u16,
|
||||||
|
_buffer: *const u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Header size for MSV1_0_S4U_LOGON (MessageType + Flags + 2x UNICODE_STRING).
|
||||||
|
#[repr(C)]
|
||||||
|
struct MSV1_0_S4U_LOGON_HEADER {
|
||||||
|
_message_type: u32,
|
||||||
|
_flags: u32,
|
||||||
|
_user_principal_name: UnicodeStringRaw,
|
||||||
|
_domain_name: UnicodeStringRaw,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Header size for KERB_S4U_LOGON (MessageType + Flags + 2x UNICODE_STRING).
|
||||||
|
#[repr(C)]
|
||||||
|
struct KerbS4ULogonHeader {
|
||||||
|
_message_type: u32,
|
||||||
|
_flags: u32,
|
||||||
|
_client_upn: UnicodeStringRaw,
|
||||||
|
_client_realm: UnicodeStringRaw,
|
||||||
|
}
|
||||||
21
client/rdp/server/addr.go
Normal file
21
client/rdp/server/addr.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseAddr parses a string into a netip.Addr, stripping any port or zone.
|
||||||
|
func parseAddr(s string) (netip.Addr, error) {
|
||||||
|
// Try as plain IP first
|
||||||
|
if addr, err := netip.ParseAddr(s); err == nil {
|
||||||
|
return addr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try as IP:port
|
||||||
|
if addrPort, err := netip.ParseAddrPort(s); err == nil {
|
||||||
|
return addrPort.Addr(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.Addr{}, fmt.Errorf("invalid IP address: %s", s)
|
||||||
|
}
|
||||||
13
client/rdp/server/credprov_stub.go
Normal file
13
client/rdp/server/credprov_stub.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
// RegisterCredentialProvider is a no-op on non-Windows platforms.
|
||||||
|
func RegisterCredentialProvider() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterCredentialProvider is a no-op on non-Windows platforms.
|
||||||
|
func UnregisterCredentialProvider() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
66
client/rdp/server/credprov_windows.go
Normal file
66
client/rdp/server/credprov_windows.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// credProvDLLName is the filename of the credential provider DLL.
|
||||||
|
credProvDLLName = "netbird_credprov.dll"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterCredentialProvider registers the NetBird Credential Provider COM DLL
|
||||||
|
// using regsvr32. The DLL must be shipped alongside the NetBird executable.
|
||||||
|
func RegisterCredentialProvider() error {
|
||||||
|
dllPath, err := findCredProvDLL()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("find credential provider DLL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("regsvr32", "/s", dllPath)
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("regsvr32 %s: %w (output: %s)", dllPath, err, string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("registered RDP credential provider: %s", dllPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterCredentialProvider unregisters the NetBird Credential Provider COM DLL.
|
||||||
|
func UnregisterCredentialProvider() error {
|
||||||
|
dllPath, err := findCredProvDLL()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("credential provider DLL not found for unregistration: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("regsvr32", "/s", "/u", dllPath)
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("regsvr32 /u %s: %w (output: %s)", dllPath, err, string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("unregistered RDP credential provider: %s", dllPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findCredProvDLL locates the credential provider DLL next to the running executable.
|
||||||
|
func findCredProvDLL() (string, error) {
|
||||||
|
exePath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("get executable path: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dllPath := filepath.Join(filepath.Dir(exePath), credProvDLLName)
|
||||||
|
if _, err := os.Stat(dllPath); err != nil {
|
||||||
|
return "", fmt.Errorf("DLL not found at %s: %w", dllPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dllPath, nil
|
||||||
|
}
|
||||||
184
client/rdp/server/pending.go
Normal file
184
client/rdp/server/pending.go
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultSessionTTL is the default time-to-live for pending RDP sessions.
|
||||||
|
DefaultSessionTTL = 60 * time.Second
|
||||||
|
|
||||||
|
// cleanupInterval is how often the store checks for expired sessions.
|
||||||
|
cleanupInterval = 10 * time.Second
|
||||||
|
|
||||||
|
// nonceLength is the length of the nonce in bytes.
|
||||||
|
nonceLength = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
// PendingRDPSession represents an authorized but not yet consumed RDP session.
|
||||||
|
type PendingRDPSession struct {
|
||||||
|
SessionID string
|
||||||
|
PeerIP netip.Addr
|
||||||
|
OSUsername string
|
||||||
|
Domain string
|
||||||
|
JWTUserID string // for audit trail
|
||||||
|
Nonce string // replay protection
|
||||||
|
CreatedAt time.Time
|
||||||
|
ExpiresAt time.Time
|
||||||
|
consumed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// PendingStore manages pending RDP session entries with automatic expiration.
|
||||||
|
type PendingStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
sessions map[string]*PendingRDPSession // keyed by SessionID
|
||||||
|
nonces map[string]struct{} // seen nonces for replay protection
|
||||||
|
ttl time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPendingStore creates a new pending session store with the given TTL.
|
||||||
|
func NewPendingStore(ttl time.Duration) *PendingStore {
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = DefaultSessionTTL
|
||||||
|
}
|
||||||
|
return &PendingStore{
|
||||||
|
sessions: make(map[string]*PendingRDPSession),
|
||||||
|
nonces: make(map[string]struct{}),
|
||||||
|
ttl: ttl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add creates a new pending RDP session and returns it.
|
||||||
|
func (ps *PendingStore) Add(peerIP netip.Addr, osUsername, domain, jwtUserID, nonce string) (*PendingRDPSession, error) {
|
||||||
|
ps.mu.Lock()
|
||||||
|
defer ps.mu.Unlock()
|
||||||
|
|
||||||
|
// Check nonce for replay protection
|
||||||
|
if _, seen := ps.nonces[nonce]; seen {
|
||||||
|
return nil, fmt.Errorf("duplicate nonce: replay detected")
|
||||||
|
}
|
||||||
|
ps.nonces[nonce] = struct{}{}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
session := &PendingRDPSession{
|
||||||
|
SessionID: uuid.New().String(),
|
||||||
|
PeerIP: peerIP,
|
||||||
|
OSUsername: osUsername,
|
||||||
|
Domain: domain,
|
||||||
|
JWTUserID: jwtUserID,
|
||||||
|
Nonce: nonce,
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(ps.ttl),
|
||||||
|
}
|
||||||
|
|
||||||
|
ps.sessions[session.SessionID] = session
|
||||||
|
|
||||||
|
log.Debugf("RDP pending session created: id=%s peer=%s user=%s domain=%s expires=%s",
|
||||||
|
session.SessionID, peerIP, osUsername, domain, session.ExpiresAt.Format(time.RFC3339))
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryByPeerIP finds the first non-consumed, non-expired pending session for the given peer IP.
|
||||||
|
func (ps *PendingStore) QueryByPeerIP(peerIP netip.Addr) (*PendingRDPSession, bool) {
|
||||||
|
ps.mu.RLock()
|
||||||
|
defer ps.mu.RUnlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for _, session := range ps.sessions {
|
||||||
|
if session.PeerIP == peerIP && !session.consumed && now.Before(session.ExpiresAt) {
|
||||||
|
return session, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume marks a session as consumed (single-use). Returns true if the session
|
||||||
|
// was found and successfully consumed, false if it was already consumed, expired, or not found.
|
||||||
|
func (ps *PendingStore) Consume(sessionID string) bool {
|
||||||
|
ps.mu.Lock()
|
||||||
|
defer ps.mu.Unlock()
|
||||||
|
|
||||||
|
session, exists := ps.sessions[sessionID]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.consumed {
|
||||||
|
log.Debugf("RDP pending session already consumed: id=%s", sessionID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(session.ExpiresAt) {
|
||||||
|
log.Debugf("RDP pending session expired: id=%s", sessionID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
session.consumed = true
|
||||||
|
log.Debugf("RDP pending session consumed: id=%s peer=%s user=%s",
|
||||||
|
sessionID, session.PeerIP, session.OSUsername)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartCleanup runs a background goroutine that periodically removes expired sessions.
|
||||||
|
func (ps *PendingStore) StartCleanup(ctx context.Context) {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(cleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
ps.cleanup()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup removes expired and consumed sessions.
|
||||||
|
func (ps *PendingStore) cleanup() {
|
||||||
|
ps.mu.Lock()
|
||||||
|
defer ps.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for id, session := range ps.sessions {
|
||||||
|
if now.After(session.ExpiresAt) || session.consumed {
|
||||||
|
delete(ps.sessions, id)
|
||||||
|
delete(ps.nonces, session.Nonce)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the number of active (non-expired, non-consumed) sessions.
|
||||||
|
func (ps *PendingStore) Count() int {
|
||||||
|
ps.mu.RLock()
|
||||||
|
defer ps.mu.RUnlock()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
now := time.Now()
|
||||||
|
for _, session := range ps.sessions {
|
||||||
|
if !session.consumed && now.Before(session.ExpiresAt) {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateNonce creates a cryptographically random nonce for replay protection.
|
||||||
|
func GenerateNonce() (string, error) {
|
||||||
|
b := make([]byte, nonceLength)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", fmt.Errorf("generate nonce: %w", err)
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(b), nil
|
||||||
|
}
|
||||||
268
client/rdp/server/pending_test.go
Normal file
268
client/rdp/server/pending_test.go
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPendingStore_AddAndQuery(t *testing.T) {
|
||||||
|
store := NewPendingStore(DefaultSessionTTL)
|
||||||
|
|
||||||
|
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Add failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.SessionID == "" {
|
||||||
|
t.Fatal("expected non-empty session ID")
|
||||||
|
}
|
||||||
|
if session.PeerIP != peerIP {
|
||||||
|
t.Errorf("expected peer IP %s, got %s", peerIP, session.PeerIP)
|
||||||
|
}
|
||||||
|
if session.OSUsername != "admin" {
|
||||||
|
t.Errorf("expected username admin, got %s", session.OSUsername)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query should find the session
|
||||||
|
found, ok := store.QueryByPeerIP(peerIP)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected to find pending session")
|
||||||
|
}
|
||||||
|
if found.SessionID != session.SessionID {
|
||||||
|
t.Errorf("expected session %s, got %s", session.SessionID, found.SessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query for different IP should not find anything
|
||||||
|
_, ok = store.QueryByPeerIP(netip.MustParseAddr("100.64.0.2"))
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected no session for different IP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPendingStore_Consume(t *testing.T) {
|
||||||
|
store := NewPendingStore(DefaultSessionTTL)
|
||||||
|
|
||||||
|
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Add failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First consume should succeed
|
||||||
|
if !store.Consume(session.SessionID) {
|
||||||
|
t.Fatal("expected first consume to succeed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second consume should fail (already consumed)
|
||||||
|
if store.Consume(session.SessionID) {
|
||||||
|
t.Fatal("expected second consume to fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query should no longer find consumed session
|
||||||
|
_, ok := store.QueryByPeerIP(peerIP)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected consumed session to not be found by query")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPendingStore_Expiry(t *testing.T) {
|
||||||
|
store := NewPendingStore(50 * time.Millisecond)
|
||||||
|
|
||||||
|
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-3")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Add failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be found immediately
|
||||||
|
_, ok := store.QueryByPeerIP(peerIP)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected to find session before expiry")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for expiry
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Should not be found after expiry
|
||||||
|
_, ok = store.QueryByPeerIP(peerIP)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected session to be expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume should also fail
|
||||||
|
if store.Consume(session.SessionID) {
|
||||||
|
t.Fatal("expected consume of expired session to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPendingStore_ReplayProtection(t *testing.T) {
|
||||||
|
store := NewPendingStore(DefaultSessionTTL)
|
||||||
|
|
||||||
|
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-same")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first Add failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same nonce should be rejected
|
||||||
|
_, err = store.Add(peerIP, "admin", ".", "user@example.com", "nonce-same")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected duplicate nonce to be rejected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPendingStore_Cleanup(t *testing.T) {
|
||||||
|
store := NewPendingStore(50 * time.Millisecond)
|
||||||
|
|
||||||
|
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-cleanup")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Add failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if store.Count() != 1 {
|
||||||
|
t.Fatalf("expected count 1, got %d", store.Count())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for expiry then trigger cleanup
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
store.cleanup()
|
||||||
|
|
||||||
|
if store.Count() != 0 {
|
||||||
|
t.Fatalf("expected count 0 after cleanup, got %d", store.Count())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPendingStore_CleanupBackground(t *testing.T) {
|
||||||
|
store := NewPendingStore(50 * time.Millisecond)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
store.StartCleanup(ctx)
|
||||||
|
|
||||||
|
peerIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-bg-cleanup")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Add failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for expiry + cleanup interval
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
_, ok := store.QueryByPeerIP(peerIP)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected session to be cleaned up")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPendingStore_ConcurrentAccess(t *testing.T) {
|
||||||
|
store := NewPendingStore(DefaultSessionTTL)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
ip := netip.AddrFrom4([4]byte{100, 64, byte(i / 256), byte(i % 256)})
|
||||||
|
nonce := "nonce-" + string(rune(i+'A'))
|
||||||
|
if i >= 26 {
|
||||||
|
nonce = "nonce-" + string(rune(i-26+'a'))
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := store.Add(ip, "admin", ".", "user", nonce)
|
||||||
|
if err != nil {
|
||||||
|
return // nonce collision in test is expected
|
||||||
|
}
|
||||||
|
|
||||||
|
store.QueryByPeerIP(ip)
|
||||||
|
store.Consume(session.SessionID)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPendingStore_MultipleSessions(t *testing.T) {
|
||||||
|
store := NewPendingStore(DefaultSessionTTL)
|
||||||
|
|
||||||
|
ip1 := netip.MustParseAddr("100.64.0.1")
|
||||||
|
ip2 := netip.MustParseAddr("100.64.0.2")
|
||||||
|
|
||||||
|
s1, err := store.Add(ip1, "admin", ".", "user1", "nonce-a")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Add s1 failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s2, err := store.Add(ip2, "jdoe", "DOMAIN", "user2", "nonce-b")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Add s2 failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query each
|
||||||
|
found1, ok := store.QueryByPeerIP(ip1)
|
||||||
|
if !ok || found1.SessionID != s1.SessionID {
|
||||||
|
t.Fatal("expected to find s1")
|
||||||
|
}
|
||||||
|
|
||||||
|
found2, ok := store.QueryByPeerIP(ip2)
|
||||||
|
if !ok || found2.SessionID != s2.SessionID {
|
||||||
|
t.Fatal("expected to find s2")
|
||||||
|
}
|
||||||
|
|
||||||
|
if found2.Domain != "DOMAIN" {
|
||||||
|
t.Errorf("expected domain DOMAIN, got %s", found2.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
if store.Count() != 2 {
|
||||||
|
t.Errorf("expected count 2, got %d", store.Count())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateNonce(t *testing.T) {
|
||||||
|
nonce1, err := GenerateNonce()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateNonce failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce2, err := GenerateNonce()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateNonce failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(nonce1) != nonceLength*2 { // hex encoding doubles the length
|
||||||
|
t.Errorf("expected nonce length %d, got %d", nonceLength*2, len(nonce1))
|
||||||
|
}
|
||||||
|
|
||||||
|
if nonce1 == nonce2 {
|
||||||
|
t.Error("expected unique nonces")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWindowsUsername(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expectedUser string
|
||||||
|
expectedDomain string
|
||||||
|
}{
|
||||||
|
{"admin", "admin", "."},
|
||||||
|
{"DOMAIN\\admin", "admin", "DOMAIN"},
|
||||||
|
{"admin@domain.com", "admin", "domain.com"},
|
||||||
|
{".\\localuser", "localuser", "."},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
user, domain := parseWindowsUsername(tt.input)
|
||||||
|
if user != tt.expectedUser {
|
||||||
|
t.Errorf("parseWindowsUsername(%q) user = %q, want %q", tt.input, user, tt.expectedUser)
|
||||||
|
}
|
||||||
|
if domain != tt.expectedDomain {
|
||||||
|
t.Errorf("parseWindowsUsername(%q) domain = %q, want %q", tt.input, domain, tt.expectedDomain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
19
client/rdp/server/pipe_stub.go
Normal file
19
client/rdp/server/pipe_stub.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type stubPipeServer struct{}
|
||||||
|
|
||||||
|
func newPipeServer(_ *PendingStore) PipeServer {
|
||||||
|
return &stubPipeServer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubPipeServer) Start(_ context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubPipeServer) Stop() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
164
client/rdp/server/pipe_windows.go
Normal file
164
client/rdp/server/pipe_windows.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/Microsoft/go-winio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PipeName is the named pipe path used for IPC between the NetBird agent and
|
||||||
|
// the Credential Provider DLL.
|
||||||
|
PipeName = `\\.\pipe\netbird-rdp-auth`
|
||||||
|
|
||||||
|
// pipeSDDL restricts access to LOCAL_SYSTEM (SY) and Administrators (BA).
|
||||||
|
pipeSDDL = "D:P(A;;GA;;;SY)(A;;GA;;;BA)"
|
||||||
|
|
||||||
|
// maxPipeRequestSize is the maximum size of a pipe request in bytes.
|
||||||
|
maxPipeRequestSize = 4096
|
||||||
|
)
|
||||||
|
|
||||||
|
// windowsPipeServer implements the PipeServer interface for Windows.
|
||||||
|
type windowsPipeServer struct {
|
||||||
|
pending *PendingStore
|
||||||
|
listener net.Listener
|
||||||
|
mu sync.Mutex
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPipeServer(pending *PendingStore) PipeServer {
|
||||||
|
return &windowsPipeServer{
|
||||||
|
pending: pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *windowsPipeServer) Start(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||||
|
|
||||||
|
cfg := &winio.PipeConfig{
|
||||||
|
SecurityDescriptor: pipeSDDL,
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := winio.ListenPipe(PipeName, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.listener = listener
|
||||||
|
|
||||||
|
go s.acceptLoop()
|
||||||
|
|
||||||
|
log.Infof("RDP named pipe server started on %s", PipeName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *windowsPipeServer) Stop() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.cancel != nil {
|
||||||
|
s.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.listener != nil {
|
||||||
|
err := s.listener.Close()
|
||||||
|
s.listener = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *windowsPipeServer) acceptLoop() {
|
||||||
|
for {
|
||||||
|
conn, err := s.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if s.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("RDP pipe accept error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.handlePipeConnection(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *windowsPipeServer) handlePipeConnection(conn net.Conn) {
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Debugf("RDP pipe close: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(io.LimitReader(conn, maxPipeRequestSize))
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("RDP pipe read: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req PipeRequest
|
||||||
|
if err := json.Unmarshal(data, &req); err != nil {
|
||||||
|
log.Debugf("RDP pipe unmarshal: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp PipeResponse
|
||||||
|
|
||||||
|
switch req.Action {
|
||||||
|
case PipeActionQuery:
|
||||||
|
resp = s.handleQuery(req.RemoteIP)
|
||||||
|
case PipeActionConsume:
|
||||||
|
resp = s.handleConsume(req.SessionID)
|
||||||
|
default:
|
||||||
|
log.Debugf("RDP pipe unknown action: %s", req.Action)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respData, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("RDP pipe marshal response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Write(respData); err != nil {
|
||||||
|
log.Debugf("RDP pipe write response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *windowsPipeServer) handleQuery(remoteIP string) PipeResponse {
|
||||||
|
peerIP, err := parseAddr(remoteIP)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("RDP pipe invalid remote IP: %s", remoteIP)
|
||||||
|
return PipeResponse{Found: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
session, found := s.pending.QueryByPeerIP(peerIP)
|
||||||
|
if !found {
|
||||||
|
return PipeResponse{Found: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
return PipeResponse{
|
||||||
|
Found: true,
|
||||||
|
SessionID: session.SessionID,
|
||||||
|
OSUser: session.OSUsername,
|
||||||
|
Domain: session.Domain,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *windowsPipeServer) handleConsume(sessionID string) PipeResponse {
|
||||||
|
if s.pending.Consume(sessionID) {
|
||||||
|
return PipeResponse{Found: true, SessionID: sessionID}
|
||||||
|
}
|
||||||
|
return PipeResponse{Found: false}
|
||||||
|
}
|
||||||
48
client/rdp/server/protocol.go
Normal file
48
client/rdp/server/protocol.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
// AuthRequest is the sideband authorization request sent by the connecting peer
|
||||||
|
// to the target peer's RDP auth server over the WireGuard tunnel.
|
||||||
|
type AuthRequest struct {
|
||||||
|
JWTToken string `json:"jwt_token"`
|
||||||
|
RequestedUser string `json:"requested_user"`
|
||||||
|
ClientPeerIP string `json:"client_peer_ip"`
|
||||||
|
Nonce string `json:"nonce"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthResponse is the sideband authorization response sent by the target peer
|
||||||
|
// back to the connecting peer.
|
||||||
|
type AuthResponse struct {
|
||||||
|
Status string `json:"status"` // "authorized" or "denied"
|
||||||
|
SessionID string `json:"session_id,omitempty"`
|
||||||
|
ExpiresAt int64 `json:"expires_at,omitempty"` // unix timestamp
|
||||||
|
OSUser string `json:"os_user,omitempty"`
|
||||||
|
Reason string `json:"reason,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PipeRequest is the IPC request from the Credential Provider DLL to the NetBird agent
|
||||||
|
// via the named pipe.
|
||||||
|
type PipeRequest struct {
|
||||||
|
Action string `json:"action"` // "query_pending" or "consume"
|
||||||
|
RemoteIP string `json:"remote_ip"` // connecting peer's WG IP
|
||||||
|
SessionID string `json:"session_id,omitempty"` // for consume action
|
||||||
|
}
|
||||||
|
|
||||||
|
// PipeResponse is the IPC response from the NetBird agent to the Credential Provider DLL.
|
||||||
|
type PipeResponse struct {
|
||||||
|
Found bool `json:"found"`
|
||||||
|
SessionID string `json:"session_id,omitempty"`
|
||||||
|
OSUser string `json:"os_user,omitempty"`
|
||||||
|
Domain string `json:"domain,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// StatusAuthorized indicates the RDP session was authorized.
|
||||||
|
StatusAuthorized = "authorized"
|
||||||
|
// StatusDenied indicates the RDP session was denied.
|
||||||
|
StatusDenied = "denied"
|
||||||
|
|
||||||
|
// PipeActionQuery queries for a pending session by remote IP.
|
||||||
|
PipeActionQuery = "query_pending"
|
||||||
|
// PipeActionConsume marks a pending session as consumed.
|
||||||
|
PipeActionConsume = "consume"
|
||||||
|
)
|
||||||
301
client/rdp/server/server.go
Normal file
301
client/rdp/server/server.go
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// InternalRDPAuthPort is the port the sideband auth server listens on.
|
||||||
|
InternalRDPAuthPort = 22338
|
||||||
|
|
||||||
|
// DefaultRDPAuthPort is the external port on the WireGuard interface (DNAT target).
|
||||||
|
DefaultRDPAuthPort = 22338
|
||||||
|
|
||||||
|
// maxRequestSize is the maximum size of an auth request in bytes.
|
||||||
|
maxRequestSize = 64 * 1024
|
||||||
|
|
||||||
|
// connectionTimeout is the timeout for a single auth connection.
|
||||||
|
connectionTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// JWTValidator validates JWT tokens and extracts user identity.
|
||||||
|
type JWTValidator interface {
|
||||||
|
ValidateAndExtract(token string) (userID string, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorizer checks if a user is authorized for RDP access.
|
||||||
|
type Authorizer interface {
|
||||||
|
Authorize(jwtUserID, osUsername string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server is the sideband RDP authorization server that listens on the WireGuard interface.
|
||||||
|
type Server struct {
|
||||||
|
listener net.Listener
|
||||||
|
pending *PendingStore
|
||||||
|
pipeServer PipeServer
|
||||||
|
jwtValidator JWTValidator
|
||||||
|
authorizer Authorizer
|
||||||
|
sshAuthorizer *sshauth.Authorizer // reuses SSH ACL for RDP access control
|
||||||
|
networkAddr netip.Prefix // WireGuard network for source IP validation
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// PipeServer is the interface for the named pipe IPC server (platform-specific).
|
||||||
|
type PipeServer interface {
|
||||||
|
Start(ctx context.Context) error
|
||||||
|
Stop() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config holds the configuration for the RDP auth server.
|
||||||
|
type Config struct {
|
||||||
|
JWTValidator JWTValidator
|
||||||
|
Authorizer Authorizer
|
||||||
|
NetworkAddr netip.Prefix
|
||||||
|
SessionTTL time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new RDP sideband auth server.
|
||||||
|
func New(cfg *Config) *Server {
|
||||||
|
ttl := cfg.SessionTTL
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = DefaultSessionTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
pending := NewPendingStore(ttl)
|
||||||
|
|
||||||
|
return &Server{
|
||||||
|
pending: pending,
|
||||||
|
pipeServer: newPipeServer(pending),
|
||||||
|
jwtValidator: cfg.JWTValidator,
|
||||||
|
authorizer: cfg.Authorizer,
|
||||||
|
sshAuthorizer: sshauth.NewAuthorizer(),
|
||||||
|
networkAddr: cfg.NetworkAddr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRDPAuth updates the RDP authorization config (reuses SSH ACL).
|
||||||
|
func (s *Server) UpdateRDPAuth(config *sshauth.Config) {
|
||||||
|
s.sshAuthorizer.Update(config)
|
||||||
|
log.Debugf("RDP auth: updated authorization config")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins listening for sideband auth requests on the given address.
|
||||||
|
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.listener != nil {
|
||||||
|
return errors.New("RDP auth server already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||||
|
|
||||||
|
listenAddr := net.TCPAddrFromAddrPort(addr)
|
||||||
|
listener, err := net.ListenTCP("tcp", listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
s.listener = listener
|
||||||
|
|
||||||
|
s.pending.StartCleanup(s.ctx)
|
||||||
|
|
||||||
|
if s.pipeServer != nil {
|
||||||
|
if err := s.pipeServer.Start(s.ctx); err != nil {
|
||||||
|
log.Warnf("failed to start RDP named pipe server: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.acceptLoop()
|
||||||
|
|
||||||
|
log.Infof("RDP sideband auth server started on %s", addr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop shuts down the server and cleans up resources.
|
||||||
|
func (s *Server) Stop() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.cancel != nil {
|
||||||
|
s.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.pipeServer != nil {
|
||||||
|
if err := s.pipeServer.Stop(); err != nil {
|
||||||
|
log.Warnf("failed to stop RDP named pipe server: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.listener != nil {
|
||||||
|
err := s.listener.Close()
|
||||||
|
s.listener = nil
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("close listener: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("RDP sideband auth server stopped")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingStore returns the pending session store (for testing/named pipe access).
|
||||||
|
func (s *Server) GetPendingStore() *PendingStore {
|
||||||
|
return s.pending
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) acceptLoop() {
|
||||||
|
for {
|
||||||
|
conn, err := s.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if s.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("RDP auth accept error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.handleConnection(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleConnection(conn net.Conn) {
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Debugf("RDP auth close connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := conn.SetDeadline(time.Now().Add(connectionTimeout)); err != nil {
|
||||||
|
log.Debugf("RDP auth set deadline: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate source IP is from WireGuard network
|
||||||
|
remoteAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("RDP auth parse remote addr: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.networkAddr.Contains(remoteAddr.Addr()) {
|
||||||
|
log.Warnf("RDP auth rejected connection from non-WG address: %s", remoteAddr.Addr())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read request
|
||||||
|
data, err := io.ReadAll(io.LimitReader(conn, maxRequestSize))
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("RDP auth read request: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req AuthRequest
|
||||||
|
if err := json.Unmarshal(data, &req); err != nil {
|
||||||
|
log.Debugf("RDP auth unmarshal request: %v", err)
|
||||||
|
s.sendResponse(conn, &AuthResponse{Status: StatusDenied, Reason: "invalid request format"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := s.processAuthRequest(remoteAddr.Addr(), &req)
|
||||||
|
s.sendResponse(conn, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) processAuthRequest(peerIP netip.Addr, req *AuthRequest) *AuthResponse {
|
||||||
|
// Validate JWT
|
||||||
|
if s.jwtValidator == nil {
|
||||||
|
// No JWT validation configured - for POC, accept all requests from WG peers
|
||||||
|
log.Warnf("RDP auth: no JWT validator configured, accepting request from %s", peerIP)
|
||||||
|
return s.createSession(peerIP, req, "no-jwt-validation")
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := s.jwtValidator.ValidateAndExtract(req.JWTToken)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("RDP auth JWT validation failed for %s: %v", peerIP, err)
|
||||||
|
return &AuthResponse{Status: StatusDenied, Reason: "JWT validation failed"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check authorization - try explicit authorizer first, then SSH ACL
|
||||||
|
if s.authorizer != nil {
|
||||||
|
if _, err := s.authorizer.Authorize(userID, req.RequestedUser); err != nil {
|
||||||
|
log.Warnf("RDP auth denied for user %s -> %s: %v", userID, req.RequestedUser, err)
|
||||||
|
return &AuthResponse{Status: StatusDenied, Reason: "not authorized for this user"}
|
||||||
|
}
|
||||||
|
} else if s.sshAuthorizer != nil {
|
||||||
|
if _, err := s.sshAuthorizer.Authorize(userID, req.RequestedUser); err != nil {
|
||||||
|
log.Warnf("RDP auth denied (SSH ACL) for user %s -> %s: %v", userID, req.RequestedUser, err)
|
||||||
|
return &AuthResponse{Status: StatusDenied, Reason: "not authorized for this user"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.createSession(peerIP, req, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) createSession(peerIP netip.Addr, req *AuthRequest, jwtUserID string) *AuthResponse {
|
||||||
|
// Parse domain from requested user (DOMAIN\user or user@domain)
|
||||||
|
osUser, domain := parseWindowsUsername(req.RequestedUser)
|
||||||
|
|
||||||
|
session, err := s.pending.Add(peerIP, osUser, domain, jwtUserID, req.Nonce)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("RDP auth create session failed: %v", err)
|
||||||
|
return &AuthResponse{Status: StatusDenied, Reason: err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AuthResponse{
|
||||||
|
Status: StatusAuthorized,
|
||||||
|
SessionID: session.SessionID,
|
||||||
|
ExpiresAt: session.ExpiresAt.Unix(),
|
||||||
|
OSUser: session.OSUsername,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) sendResponse(conn net.Conn, resp *AuthResponse) {
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("RDP auth marshal response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Write(data); err != nil {
|
||||||
|
log.Debugf("RDP auth write response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseWindowsUsername extracts username and domain from Windows username formats.
|
||||||
|
// Supports DOMAIN\username, username@domain, and plain username.
|
||||||
|
func parseWindowsUsername(fullUsername string) (username, domain string) {
|
||||||
|
for i := len(fullUsername) - 1; i >= 0; i-- {
|
||||||
|
if fullUsername[i] == '\\' {
|
||||||
|
return fullUsername[i+1:], fullUsername[:i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if idx := indexOf(fullUsername, '@'); idx != -1 {
|
||||||
|
return fullUsername[:idx], fullUsername[idx+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return fullUsername, "."
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexOf(s string, c byte) int {
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
if s[i] == c {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
@@ -366,6 +366,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
|||||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||||
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
||||||
|
config.ServerRDPAllowed = msg.ServerRDPAllowed
|
||||||
config.NetworkMonitor = msg.NetworkMonitor
|
config.NetworkMonitor = msg.NetworkMonitor
|
||||||
config.DisableClientRoutes = msg.DisableClientRoutes
|
config.DisableClientRoutes = msg.DisableClientRoutes
|
||||||
config.DisableServerRoutes = msg.DisableServerRoutes
|
config.DisableServerRoutes = msg.DisableServerRoutes
|
||||||
@@ -1514,6 +1515,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
|||||||
Mtu: int64(cfg.MTU),
|
Mtu: int64(cfg.MTU),
|
||||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||||
|
ServerRDPAllowed: cfg.ServerRDPAllowed != nil && *cfg.ServerRDPAllowed,
|
||||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
rosenpassEnabled := true
|
rosenpassEnabled := true
|
||||||
rosenpassPermissive := true
|
rosenpassPermissive := true
|
||||||
serverSSHAllowed := true
|
serverSSHAllowed := true
|
||||||
|
serverRDPAllowed := true
|
||||||
interfaceName := "utun100"
|
interfaceName := "utun100"
|
||||||
wireguardPort := int64(51820)
|
wireguardPort := int64(51820)
|
||||||
preSharedKey := "test-psk"
|
preSharedKey := "test-psk"
|
||||||
@@ -82,6 +83,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
RosenpassEnabled: &rosenpassEnabled,
|
RosenpassEnabled: &rosenpassEnabled,
|
||||||
RosenpassPermissive: &rosenpassPermissive,
|
RosenpassPermissive: &rosenpassPermissive,
|
||||||
ServerSSHAllowed: &serverSSHAllowed,
|
ServerSSHAllowed: &serverSSHAllowed,
|
||||||
|
ServerRDPAllowed: &serverRDPAllowed,
|
||||||
InterfaceName: &interfaceName,
|
InterfaceName: &interfaceName,
|
||||||
WireguardPort: &wireguardPort,
|
WireguardPort: &wireguardPort,
|
||||||
OptionalPreSharedKey: &preSharedKey,
|
OptionalPreSharedKey: &preSharedKey,
|
||||||
@@ -125,6 +127,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||||
require.NotNil(t, cfg.ServerSSHAllowed)
|
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||||
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||||
|
require.NotNil(t, cfg.ServerRDPAllowed)
|
||||||
|
require.Equal(t, serverRDPAllowed, *cfg.ServerRDPAllowed)
|
||||||
require.Equal(t, interfaceName, cfg.WgIface)
|
require.Equal(t, interfaceName, cfg.WgIface)
|
||||||
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||||
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||||
@@ -176,6 +180,7 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
|||||||
"RosenpassEnabled": true,
|
"RosenpassEnabled": true,
|
||||||
"RosenpassPermissive": true,
|
"RosenpassPermissive": true,
|
||||||
"ServerSSHAllowed": true,
|
"ServerSSHAllowed": true,
|
||||||
|
"ServerRDPAllowed": true,
|
||||||
"InterfaceName": true,
|
"InterfaceName": true,
|
||||||
"WireguardPort": true,
|
"WireguardPort": true,
|
||||||
"OptionalPreSharedKey": true,
|
"OptionalPreSharedKey": true,
|
||||||
@@ -236,6 +241,7 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
|||||||
"enable-rosenpass": "RosenpassEnabled",
|
"enable-rosenpass": "RosenpassEnabled",
|
||||||
"rosenpass-permissive": "RosenpassPermissive",
|
"rosenpass-permissive": "RosenpassPermissive",
|
||||||
"allow-server-ssh": "ServerSSHAllowed",
|
"allow-server-ssh": "ServerSSHAllowed",
|
||||||
|
"allow-server-rdp": "ServerRDPAllowed",
|
||||||
"interface-name": "InterfaceName",
|
"interface-name": "InterfaceName",
|
||||||
"wireguard-port": "WireguardPort",
|
"wireguard-port": "WireguardPort",
|
||||||
"preshared-key": "OptionalPreSharedKey",
|
"preshared-key": "OptionalPreSharedKey",
|
||||||
|
|||||||
Reference in New Issue
Block a user