[client] Add opt-in --allow-server-rdp flag, reuse SSH ACL, dynamic DLL registration

- Add ServerRDPAllowed field to daemon proto, EngineConfig, and profile config
- Add --allow-server-rdp flag to `netbird up` (opt-in, defaults to false)
- Wire RDP server start/stop in engine based on the flag
- Reuse SSH ACL (SSHAuth proto) for RDP authorization via sshauth.Authorizer
- Register/unregister credential provider COM DLL dynamically when flag is toggled
- Ship DLL alongside netbird.exe, register via regsvr32 at runtime (not install time)
- Update SetConfig tests to cover the new field

https://claude.ai/code/session_01C38bCDyYzLgxYLVwJkcUng
This commit is contained in:
Claude
2026-04-11 17:48:53 +00:00
parent c5186f1483
commit 4949ca6194
13 changed files with 252 additions and 10 deletions

View File

@@ -25,17 +25,24 @@ import (
"github.com/netbirdio/netbird/util"
)
const (
serverRDPAllowedFlag = "allow-server-rdp"
)
var (
rdpUsername string
rdpHost string
rdpNoBrowser bool
rdpNoCache bool
rdpUsername string
rdpHost string
rdpNoBrowser bool
rdpNoCache bool
serverRDPAllowed bool
)
func init() {
rdpCmd.PersistentFlags().StringVarP(&rdpUsername, "user", "u", "", "Windows username on remote peer")
rdpCmd.PersistentFlags().BoolVar(&rdpNoBrowser, noBrowserFlag, false, noBrowserDesc)
rdpCmd.PersistentFlags().BoolVar(&rdpNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
upCmd.PersistentFlags().BoolVar(&serverRDPAllowed, serverRDPAllowedFlag, false, "Allow RDP passthrough on peer (passwordless RDP via credential provider)")
}
var rdpCmd = &cobra.Command{

View File

@@ -356,6 +356,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverRDPAllowedFlag).Changed {
req.ServerRDPAllowed = &serverRDPAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
@@ -458,6 +461,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverRDPAllowedFlag).Changed {
ic.ServerRDPAllowed = &serverRDPAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
@@ -582,6 +588,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverRDPAllowedFlag).Changed {
loginRequest.ServerRDPAllowed = &serverRDPAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot

View File

@@ -543,6 +543,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
ServerRDPAllowed: config.ServerRDPAllowed != nil && *config.ServerRDPAllowed,
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,

View File

@@ -117,6 +117,7 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
ServerRDPAllowed bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -1037,6 +1038,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
}
if err := e.updateRDP(); err != nil {
log.Warnf("failed handling RDP server setup: %v", err)
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
@@ -1324,6 +1329,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
// Reuse SSH ACL for RDP authorization
e.updateRDPServerAuth(networkMap.GetSshAuth())
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store

View File

@@ -10,13 +10,17 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
type rdpServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
GetPendingStore() *rdpserver.PendingStore
UpdateRDPAuth(config *sshauth.Config)
}
func (e *Engine) setupRDPPortRedirection() error {
@@ -57,6 +61,28 @@ func (e *Engine) cleanupRDPPortRedirection() error {
return nil
}
// updateRDP handles starting/stopping the RDP server based on the config flag.
func (e *Engine) updateRDP() error {
if !e.config.ServerRDPAllowed {
if e.rdpServer != nil {
log.Info("RDP passthrough disabled, stopping RDP auth server")
}
return e.stopRDPServer()
}
if e.config.BlockInbound {
log.Info("RDP server is disabled because inbound connections are blocked")
return e.stopRDPServer()
}
if e.rdpServer != nil {
log.Debug("RDP auth server is already running")
return nil
}
return e.startRDPServer()
}
func (e *Engine) startRDPServer() error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
@@ -92,6 +118,12 @@ func (e *Engine) startRDPServer() error {
log.Warnf("failed to setup RDP auth port redirection: %v", err)
}
// Register the credential provider DLL dynamically (Windows only)
if err := rdpserver.RegisterCredentialProvider(); err != nil {
log.Warnf("failed to register RDP credential provider (passwordless RDP will not work): %v", err)
}
log.Info("RDP passthrough enabled")
return nil
}
@@ -113,6 +145,11 @@ func (e *Engine) stopRDPServer() error {
}
}
// Unregister the credential provider DLL (Windows only)
if err := rdpserver.UnregisterCredentialProvider(); err != nil {
log.Warnf("failed to unregister RDP credential provider: %v", err)
}
log.Info("stopping RDP auth server")
err := e.rdpServer.Stop()
e.rdpServer = nil
@@ -121,3 +158,34 @@ func (e *Engine) stopRDPServer() error {
}
return nil
}
// updateRDPServerAuth reuses the SSH authorization config for RDP access control.
// This means the same user/machine-user mappings that control SSH access also control RDP.
func (e *Engine) updateRDPServerAuth(sshAuth *mgmProto.SSHAuth) {
if sshAuth == nil || e.rdpServer == nil {
return
}
protoUsers := sshAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid hash length %d, expected 16 - skipping RDP server auth update", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range sshAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
authConfig := &sshauth.Config{
UserIDClaim: sshAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
}
e.rdpServer.UpdateRDPAuth(authConfig)
}

View File

@@ -64,6 +64,7 @@ type ConfigInput struct {
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
ServerRDPAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -114,6 +115,7 @@ type Config struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
ServerRDPAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -415,6 +417,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerRDPAllowed != nil {
if config.ServerRDPAllowed == nil || *input.ServerRDPAllowed != *config.ServerRDPAllowed {
if *input.ServerRDPAllowed {
log.Infof("enabling RDP passthrough")
} else {
log.Infof("disabling RDP passthrough")
}
config.ServerRDPAllowed = input.ServerRDPAllowed
updated = true
}
} else if config.ServerRDPAllowed == nil {
config.ServerRDPAllowed = util.False()
updated = true
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")

View File

@@ -472,6 +472,7 @@ type LoginRequest struct {
EnableSSHRemotePortForwarding *bool `protobuf:"varint,37,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth *bool `protobuf:"varint,38,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
SshJWTCacheTTL *int32 `protobuf:"varint,39,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
ServerRDPAllowed *bool `protobuf:"varint,40,opt,name=serverRDPAllowed,proto3,oneof" json:"serverRDPAllowed,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -780,6 +781,13 @@ func (x *LoginRequest) GetSshJWTCacheTTL() int32 {
return 0
}
func (x *LoginRequest) GetServerRDPAllowed() bool {
if x != nil && x.ServerRDPAllowed != nil {
return *x.ServerRDPAllowed
}
return false
}
type LoginResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
@@ -1312,6 +1320,7 @@ type GetConfigResponse struct {
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"`
ServerRDPAllowed bool `protobuf:"varint,27,opt,name=serverRDPAllowed,proto3" json:"serverRDPAllowed,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -1528,6 +1537,13 @@ func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 {
return 0
}
func (x *GetConfigResponse) GetServerRDPAllowed() bool {
if x != nil {
return x.ServerRDPAllowed
}
return false
}
// PeerState contains the latest state of a peer
type PeerState struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -4139,6 +4155,7 @@ type SetConfigRequest struct {
EnableSSHRemotePortForwarding *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
SshJWTCacheTTL *int32 `protobuf:"varint,34,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
ServerRDPAllowed *bool `protobuf:"varint,35,opt,name=serverRDPAllowed,proto3,oneof" json:"serverRDPAllowed,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4411,6 +4428,13 @@ func (x *SetConfigRequest) GetSshJWTCacheTTL() int32 {
return 0
}
func (x *SetConfigRequest) GetServerRDPAllowed() bool {
if x != nil && x.ServerRDPAllowed != nil {
return *x.ServerRDPAllowed
}
return false
}
type SetConfigResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields

View File

@@ -209,6 +209,8 @@ message LoginRequest {
optional bool enableSSHRemotePortForwarding = 37;
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
optional bool serverRDPAllowed = 40;
}
message LoginResponse {
@@ -316,6 +318,8 @@ message GetConfigResponse {
bool disableSSHAuth = 25;
int32 sshJWTCacheTTL = 26;
bool serverRDPAllowed = 27;
}
// PeerState contains the latest state of a peer
@@ -677,6 +681,8 @@ message SetConfigRequest {
optional bool enableSSHRemotePortForwarding = 32;
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
optional bool serverRDPAllowed = 35;
}
message SetConfigResponse{}

View File

@@ -0,0 +1,13 @@
//go:build !windows
package server
// RegisterCredentialProvider is a no-op on non-Windows platforms.
func RegisterCredentialProvider() error {
return nil
}
// UnregisterCredentialProvider is a no-op on non-Windows platforms.
func UnregisterCredentialProvider() error {
return nil
}

View File

@@ -0,0 +1,66 @@
//go:build windows
package server
import (
"fmt"
"os"
"os/exec"
"path/filepath"
log "github.com/sirupsen/logrus"
)
const (
// credProvDLLName is the filename of the credential provider DLL.
credProvDLLName = "netbird_credprov.dll"
)
// RegisterCredentialProvider registers the NetBird Credential Provider COM DLL
// using regsvr32. The DLL must be shipped alongside the NetBird executable.
func RegisterCredentialProvider() error {
dllPath, err := findCredProvDLL()
if err != nil {
return fmt.Errorf("find credential provider DLL: %w", err)
}
cmd := exec.Command("regsvr32", "/s", dllPath)
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("regsvr32 %s: %w (output: %s)", dllPath, err, string(output))
}
log.Infof("registered RDP credential provider: %s", dllPath)
return nil
}
// UnregisterCredentialProvider unregisters the NetBird Credential Provider COM DLL.
func UnregisterCredentialProvider() error {
dllPath, err := findCredProvDLL()
if err != nil {
log.Debugf("credential provider DLL not found for unregistration: %v", err)
return nil
}
cmd := exec.Command("regsvr32", "/s", "/u", dllPath)
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("regsvr32 /u %s: %w (output: %s)", dllPath, err, string(output))
}
log.Infof("unregistered RDP credential provider: %s", dllPath)
return nil
}
// findCredProvDLL locates the credential provider DLL next to the running executable.
func findCredProvDLL() (string, error) {
exePath, err := os.Executable()
if err != nil {
return "", fmt.Errorf("get executable path: %w", err)
}
dllPath := filepath.Join(filepath.Dir(exePath), credProvDLLName)
if _, err := os.Stat(dllPath); err != nil {
return "", fmt.Errorf("DLL not found at %s: %w", dllPath, err)
}
return dllPath, nil
}

View File

@@ -12,6 +12,8 @@ import (
"time"
log "github.com/sirupsen/logrus"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
)
const (
@@ -45,6 +47,7 @@ type Server struct {
pipeServer PipeServer
jwtValidator JWTValidator
authorizer Authorizer
sshAuthorizer *sshauth.Authorizer // reuses SSH ACL for RDP access control
networkAddr netip.Prefix // WireGuard network for source IP validation
mu sync.Mutex
@@ -76,14 +79,21 @@ func New(cfg *Config) *Server {
pending := NewPendingStore(ttl)
return &Server{
pending: pending,
pipeServer: newPipeServer(pending),
jwtValidator: cfg.JWTValidator,
authorizer: cfg.Authorizer,
networkAddr: cfg.NetworkAddr,
pending: pending,
pipeServer: newPipeServer(pending),
jwtValidator: cfg.JWTValidator,
authorizer: cfg.Authorizer,
sshAuthorizer: sshauth.NewAuthorizer(),
networkAddr: cfg.NetworkAddr,
}
}
// UpdateRDPAuth updates the RDP authorization config (reuses SSH ACL).
func (s *Server) UpdateRDPAuth(config *sshauth.Config) {
s.sshAuthorizer.Update(config)
log.Debugf("RDP auth: updated authorization config")
}
// Start begins listening for sideband auth requests on the given address.
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
s.mu.Lock()
@@ -219,12 +229,17 @@ func (s *Server) processAuthRequest(peerIP netip.Addr, req *AuthRequest) *AuthRe
return &AuthResponse{Status: StatusDenied, Reason: "JWT validation failed"}
}
// Check authorization
// Check authorization - try explicit authorizer first, then SSH ACL
if s.authorizer != nil {
if _, err := s.authorizer.Authorize(userID, req.RequestedUser); err != nil {
log.Warnf("RDP auth denied for user %s -> %s: %v", userID, req.RequestedUser, err)
return &AuthResponse{Status: StatusDenied, Reason: "not authorized for this user"}
}
} else if s.sshAuthorizer != nil {
if _, err := s.sshAuthorizer.Authorize(userID, req.RequestedUser); err != nil {
log.Warnf("RDP auth denied (SSH ACL) for user %s -> %s: %v", userID, req.RequestedUser, err)
return &AuthResponse{Status: StatusDenied, Reason: "not authorized for this user"}
}
}
return s.createSession(peerIP, req, userID)

View File

@@ -366,6 +366,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
config.ServerSSHAllowed = msg.ServerSSHAllowed
config.ServerRDPAllowed = msg.ServerRDPAllowed
config.NetworkMonitor = msg.NetworkMonitor
config.DisableClientRoutes = msg.DisableClientRoutes
config.DisableServerRoutes = msg.DisableServerRoutes
@@ -1514,6 +1515,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
ServerRDPAllowed: cfg.ServerRDPAllowed != nil && *cfg.ServerRDPAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,

View File

@@ -58,6 +58,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
serverRDPAllowed := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
@@ -82,6 +83,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
ServerRDPAllowed: &serverRDPAllowed,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
@@ -125,6 +127,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.NotNil(t, cfg.ServerRDPAllowed)
require.Equal(t, serverRDPAllowed, *cfg.ServerRDPAllowed)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
@@ -176,6 +180,7 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"ServerRDPAllowed": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
@@ -236,6 +241,7 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"allow-server-rdp": "ServerRDPAllowed",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",