Compare commits

...

12 Commits

Author SHA1 Message Date
Claude
b621a2628b [client] Use port 22338 for RDP sideband auth server
https://claude.ai/code/session_01C38bCDyYzLgxYLVwJkcUng
2026-04-11 18:21:22 +00:00
Claude
4949ca6194 [client] Add opt-in --allow-server-rdp flag, reuse SSH ACL, dynamic DLL registration
- Add ServerRDPAllowed field to daemon proto, EngineConfig, and profile config
- Add --allow-server-rdp flag to `netbird up` (opt-in, defaults to false)
- Wire RDP server start/stop in engine based on the flag
- Reuse SSH ACL (SSHAuth proto) for RDP authorization via sshauth.Authorizer
- Register/unregister credential provider COM DLL dynamically when flag is toggled
- Ship DLL alongside netbird.exe, register via regsvr32 at runtime (not install time)
- Update SetConfig tests to cover the new field

https://claude.ai/code/session_01C38bCDyYzLgxYLVwJkcUng
2026-04-11 17:48:53 +00:00
Claude
c5186f1483 [client] Add RDP token passthrough for passwordless Windows Remote Desktop
Implement sideband authorization and credential provider architecture for
passwordless RDP access to Windows peers via NetBird.

Go components:
- Sideband RDP auth server (TCP on WG interface, port 3390/22023)
- Pending session store with TTL expiry and replay protection
- Named pipe IPC server (\\.\pipe\netbird-rdp-auth) for credential provider
- Sideband client for connecting peer to request authorization
- CLI command `netbird rdp [user@]host` with JWT auth flow
- Engine integration with DNAT port redirection

Rust credential provider DLL (client/rdp/credprov/):
- COM DLL implementing ICredentialProvider + ICredentialProviderCredential
- Loaded by Windows LogonUI.exe at the RDP login screen
- Queries NetBird agent via named pipe for pending sessions
- Performs S4U logon (LsaLogonUser) for passwordless Windows token creation
- Self-registration via regsvr32 (DllRegisterServer/DllUnregisterServer)

https://claude.ai/code/session_01C38bCDyYzLgxYLVwJkcUng
2026-04-11 17:15:42 +00:00
Pascal Fischer
5259e5df51 [management] add domain and service cleanup migration (#5850) 2026-04-11 12:00:40 +02:00
Zoltan Papp
ebd78e0122 [client] Update RaceDial to accept context for improved cancellation handling (#5849) 2026-04-10 20:51:04 +02:00
Pascal Fischer
cf86b9a528 [management] enable access log cleanup by default (#5842) 2026-04-10 17:07:27 +02:00
Pascal Fischer
ee588e1536 Revert "[management] allow local routing peer resource (#5814)" (#5847) 2026-04-10 14:53:47 +02:00
Pascal Fischer
2a8aacc5c9 [management] allow local routing peer resource (#5814) 2026-04-10 13:08:21 +02:00
Pascal Fischer
15709bc666 [management] update account delete with proper proxy domain and service cleanup (#5817) 2026-04-10 13:08:04 +02:00
Pascal Fischer
789b4113fe [misc] update dashboards (#5840) 2026-04-10 12:15:58 +02:00
Viktor Liu
d2cdc0efec [client] Use native firewall for peer ACLs in userspace WireGuard mode (#5668) 2026-04-10 09:12:13 +08:00
Pascal Fischer
ee343d5d77 [management] use sql null vars (#5844) 2026-04-09 18:12:38 +02:00
57 changed files with 3629 additions and 105 deletions

276
client/cmd/rdp.go Normal file
View File

@@ -0,0 +1,276 @@
package cmd
import (
"context"
"errors"
"fmt"
"net"
"os"
"os/signal"
"os/user"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
rdpclient "github.com/netbirdio/netbird/client/rdp/client"
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util"
)
const (
serverRDPAllowedFlag = "allow-server-rdp"
)
var (
rdpUsername string
rdpHost string
rdpNoBrowser bool
rdpNoCache bool
serverRDPAllowed bool
)
// init registers the CLI flags for the RDP feature. The client-side flags
// (--user, --no-browser, --no-cache) live on `netbird rdp`, while
// --allow-server-rdp is registered on `netbird up` because it is an opt-in,
// server-side setting of the local peer (defaults to false).
func init() {
	rdpCmd.PersistentFlags().StringVarP(&rdpUsername, "user", "u", "", "Windows username on remote peer")
	rdpCmd.PersistentFlags().BoolVar(&rdpNoBrowser, noBrowserFlag, false, noBrowserDesc)
	rdpCmd.PersistentFlags().BoolVar(&rdpNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
	upCmd.PersistentFlags().BoolVar(&serverRDPAllowed, serverRDPAllowedFlag, false, "Allow RDP passthrough on peer (passwordless RDP via credential provider)")
}
// rdpCmd implements `netbird rdp [user@]host`: token-based passwordless RDP
// to a NetBird peer. The Long text below is a raw string literal, so its
// line layout is part of the user-visible help output.
var rdpCmd = &cobra.Command{
	Use:   "rdp [flags] [user@]host",
	Short: "Connect to a NetBird peer via RDP (passwordless)",
	Long: `Connect to a NetBird peer using Remote Desktop Protocol with token-based
passwordless authentication. The target peer must have RDP passthrough enabled.
This command:
1. Obtains a JWT token via OIDC authentication
2. Sends the token to the target peer's sideband auth service
3. If authorized, launches mstsc.exe to connect
Examples:
netbird rdp peer-hostname
netbird rdp administrator@peer-hostname
netbird rdp --user admin peer-hostname`,
	Args: cobra.MinimumNArgs(1),
	RunE: rdpFn,
}
// rdpFn is the cobra entry point for `netbird rdp`. It initializes logging,
// parses the [user@]host argument, and runs the RDP flow in a goroutine so
// that SIGINT/SIGTERM can cancel it cleanly.
//
// Fix: the previous implementation waited on rdpCtx.Done() after a signal,
// but that channel closes as soon as cancel() is called, so it never actually
// waited for runRDP to finish; a dedicated done channel is used instead, and
// the signal handler is released on return.
func rdpFn(cmd *cobra.Command, args []string) error {
	SetFlagsFromEnvVars(rootCmd)
	SetFlagsFromEnvVars(cmd)
	cmd.SetOut(cmd.OutOrStdout())

	// Log to the console unless an explicit (non-default) log file was set.
	logOutput := "console"
	if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
		logOutput = firstLogFile
	}
	if err := util.InitLog(logLevel, logOutput); err != nil {
		return fmt.Errorf("init log: %w", err)
	}

	// Parse user@host into the package-level rdpUsername/rdpHost.
	if err := parseRDPHostArg(args[0]); err != nil {
		return err
	}

	ctx := internal.CtxInitState(cmd.Context())
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
	defer signal.Stop(sig)

	rdpCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	errCh := make(chan error, 1)
	done := make(chan struct{})
	go func() {
		defer close(done)
		if err := runRDP(rdpCtx, cmd); err != nil {
			errCh <- err
		}
	}()

	select {
	case <-sig:
		// Cancel the RDP flow and wait for the goroutine to actually finish
		// before returning, so cleanup inside runRDP can complete.
		cancel()
		<-done
		return nil
	case <-done:
		// runRDP finished on its own; report its error, if any.
		select {
		case err := <-errCh:
			return err
		default:
			return nil
		}
	}
}
// parseRDPHostArg splits an optional "user@host" argument into the
// package-level rdpUsername and rdpHost. An explicit --user flag takes
// precedence over the user embedded in the argument. When no username is
// available it falls back to $SUDO_USER, then the current OS user, and
// finally "Administrator".
func parseRDPHostArg(arg string) error {
	if userPart, hostPart, hasUser := strings.Cut(arg, "@"); hasUser {
		if userPart == "" || hostPart == "" {
			return errors.New("invalid user@host format")
		}
		if rdpUsername == "" {
			rdpUsername = userPart
		}
		rdpHost = hostPart
	} else {
		rdpHost = arg
	}

	if rdpUsername != "" {
		return nil
	}

	switch {
	case os.Getenv("SUDO_USER") != "":
		rdpUsername = os.Getenv("SUDO_USER")
	default:
		if currentUser, err := user.Current(); err == nil {
			rdpUsername = currentUser.Username
		} else {
			rdpUsername = "Administrator"
		}
	}
	return nil
}
// runRDP drives the full client-side RDP flow: connect to the local daemon,
// resolve the target peer's WireGuard IP, obtain a JWT token, request
// sideband authorization from the target peer, and launch the platform's RDP
// client. It reads the package-level rdpHost/rdpUsername populated by
// parseRDPHostArg. Progress and troubleshooting hints are printed via cmd.
func runRDP(ctx context.Context, cmd *cobra.Command) error {
	// Connect to daemon over its local gRPC endpoint (no TLS: local IPC).
	grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
	grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return fmt.Errorf("connect to daemon: %w", err)
	}
	defer func() { _ = grpcConn.Close() }()
	daemonClient := proto.NewDaemonServiceClient(grpcConn)
	// Resolve peer IP from the daemon's current peer status.
	peerIP, err := resolvePeerIP(ctx, daemonClient, rdpHost)
	if err != nil {
		return fmt.Errorf("resolve peer %s: %w", rdpHost, err)
	}
	cmd.Printf("Connecting to %s@%s (%s)...\n", rdpUsername, rdpHost, peerIP)
	// Obtain JWT token. A nil browserOpener is passed when --no-browser is
	// set; presumably the auth flow then prints the URL instead — confirm
	// against RequestJWTToken's behavior.
	hint := profilemanager.GetLoginHint()
	var browserOpener func(string) error
	if !rdpNoBrowser {
		browserOpener = util.OpenBrowser
	}
	jwtToken, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !rdpNoCache, hint, browserOpener)
	if err != nil {
		return fmt.Errorf("JWT authentication: %w", err)
	}
	log.Debug("JWT authentication successful")
	cmd.Println("Authenticated. Requesting RDP access...")
	// Generate nonce for replay protection of the sideband request.
	nonce, err := rdpserver.GenerateNonce()
	if err != nil {
		return fmt.Errorf("generate nonce: %w", err)
	}
	// Send sideband auth request to the peer's auth port.
	authClient := rdpclient.New()
	authAddr := net.JoinHostPort(peerIP, fmt.Sprintf("%d", rdpserver.DefaultRDPAuthPort))
	resp, err := authClient.RequestAuth(ctx, authAddr, &rdpserver.AuthRequest{
		JWTToken:      jwtToken,
		RequestedUser: rdpUsername,
		ClientPeerIP:  "", // will be filled by the server from the connection
		Nonce:         nonce,
	})
	if err != nil {
		// Transport-level failure: print troubleshooting hints before erroring.
		cmd.Printf("Failed to authorize RDP session with %s\n", rdpHost)
		cmd.Printf("\nTroubleshooting:\n")
		cmd.Printf("  1. Check connectivity: netbird status -d\n")
		cmd.Printf("  2. Verify RDP passthrough is enabled on the target peer\n")
		return fmt.Errorf("sideband auth: %w", err)
	}
	if resp.Status != rdpserver.StatusAuthorized {
		// The peer answered but denied access (ACL/user mismatch).
		return fmt.Errorf("RDP access denied: %s", resp.Reason)
	}
	cmd.Printf("RDP access authorized (session: %s, user: %s)\n", resp.SessionID, resp.OSUser)
	cmd.Printf("Launching Remote Desktop client...\n")
	// Launch mstsc.exe (platform-specific; prints instructions on non-Windows).
	if err := launchRDPClient(peerIP); err != nil {
		return fmt.Errorf("launch RDP client: %w", err)
	}
	return nil
}
// resolvePeerIP resolves a peer hostname/FQDN to its WireGuard IP address
// by querying the daemon for the current peer status.
//
// Matching against known peers is delegated to matchesPeer (FQDN, short
// hostname, or IP). NOTE(review): the fallback below uses net.ResolveIPAddr,
// which performs a DNS lookup, so it accepts any resolvable name — not only
// literal IPs as the comment suggests; confirm whether DNS fallback is
// intended here.
func resolvePeerIP(ctx context.Context, client proto.DaemonServiceClient, peerAddress string) (string, error) {
	statusResp, err := client.Status(ctx, &proto.StatusRequest{})
	if err != nil {
		return "", fmt.Errorf("get daemon status: %w", err)
	}
	if statusResp.GetFullStatus() == nil {
		return "", errors.New("daemon returned empty status")
	}
	for _, peer := range statusResp.GetFullStatus().GetPeers() {
		if matchesPeer(peer, peerAddress) {
			ip := peer.GetIP()
			if ip == "" {
				// Matched a peer that has no IP assigned; keep scanning.
				continue
			}
			// Strip CIDR suffix if present (e.g. "100.64.0.1/16").
			if idx := strings.Index(ip, "/"); idx != -1 {
				ip = ip[:idx]
			}
			return ip, nil
		}
	}
	// If not found as a peer name, try as a direct IP (see NOTE above:
	// this may also resolve arbitrary DNS names).
	if addr, err := net.ResolveIPAddr("ip", peerAddress); err == nil {
		return addr.String(), nil
	}
	return "", fmt.Errorf("peer %q not found in network", peerAddress)
}
// matchesPeer reports whether address refers to the given peer. It matches
// case-insensitively against the peer's FQDN (with or without a trailing
// dot) and its short hostname, and exactly against the peer's WireGuard IP
// with any CIDR suffix stripped.
func matchesPeer(peer *proto.PeerState, address string) bool {
	address = strings.ToLower(address)

	fullFQDN := peer.GetFqdn()
	trimmedFQDN := strings.TrimSuffix(fullFQDN, ".")
	shortName, _, _ := strings.Cut(trimmedFQDN, ".")

	switch {
	case strings.EqualFold(fullFQDN, address),
		strings.EqualFold(trimmedFQDN, address),
		strings.EqualFold(shortName, address):
		return true
	}

	ip := peer.GetIP()
	if slash := strings.IndexByte(ip, '/'); slash != -1 {
		ip = ip[:slash]
	}
	return ip == address
}

13
client/cmd/rdp_stub.go Normal file
View File

@@ -0,0 +1,13 @@
//go:build !windows
package cmd
import "fmt"
// launchRDPClient is a stub for non-Windows platforms. Since mstsc.exe does
// not exist outside Windows, it prints manual connection instructions
// (standard RDP port 3389) to stdout and always reports success.
func launchRDPClient(peerIP string) error {
	fmt.Printf("RDP session authorized for %s\n", peerIP)
	fmt.Println("Note: mstsc.exe is only available on Windows.")
	fmt.Printf("Use any RDP client to connect to %s:3389\n", peerIP)
	return nil
}

34
client/cmd/rdp_windows.go Normal file
View File

@@ -0,0 +1,34 @@
//go:build windows
package cmd
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
)
// launchRDPClient starts the native Windows Remote Desktop client
// (mstsc.exe) pointed at the given peer IP and returns without waiting for
// it to exit; the RDP session runs independently of this process.
func launchRDPClient(peerIP string) error {
	mstscPath, err := exec.LookPath("mstsc.exe")
	if err != nil {
		return fmt.Errorf("mstsc.exe not found: %w", err)
	}

	client := exec.Command(mstscPath, fmt.Sprintf("/v:%s", peerIP))
	if err := client.Start(); err != nil {
		return fmt.Errorf("start mstsc.exe: %w", err)
	}
	log.Debugf("launched mstsc.exe (PID %d) connecting to %s", client.Process.Pid, peerIP)

	// Reap the child in the background so the process table stays clean;
	// we deliberately do not block on the session.
	go func() {
		if waitErr := client.Wait(); waitErr != nil {
			log.Debugf("mstsc.exe exited: %v", waitErr)
		}
	}()
	return nil
}

View File

@@ -150,6 +150,7 @@ func init() {
rootCmd.AddCommand(logoutCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(rdpCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)

View File

@@ -356,6 +356,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverRDPAllowedFlag).Changed {
req.ServerRDPAllowed = &serverRDPAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
@@ -458,6 +461,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverRDPAllowedFlag).Changed {
ic.ServerRDPAllowed = &serverRDPAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
@@ -582,6 +588,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverRDPAllowedFlag).Changed {
loginRequest.ServerRDPAllowed = &serverRDPAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"os"
"strconv"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
@@ -35,20 +36,27 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// on the linux system we try to use nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
log.Info("forcing userspace firewall")
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
// Kernel cannot fall back to anything else, need to return error
if !iface.IsUserspaceBind() {
return fm, err
}
// Fall back to the userspace packet filter if native is unavailable
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
return fm, nil
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
@@ -160,3 +168,17 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
// forceUserspaceFirewall reports whether the user requested the userspace
// packet filter via the EnvForceUserspaceFirewall environment variable.
// An unset, empty, or unparsable value counts as "not forced".
func forceUserspaceFirewall() bool {
	raw, ok := os.LookupEnv(EnvForceUserspaceFirewall)
	if !ok || raw == "" {
		return false
	}
	forced, err := strconv.ParseBool(raw)
	if err != nil {
		log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
		return false
	}
	return forced
}

View File

@@ -7,6 +7,12 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when
// native iptables/nftables is available. This only applies when the WireGuard interface
// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of
// kernel netfilter rules.
const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL"
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string

View File

@@ -33,7 +33,6 @@ type Manager struct {
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
IsUserspaceBind() bool
}
// Create iptables firewall manager
@@ -64,10 +63,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
},
}
stateManager.RegisterState(state)
@@ -203,12 +201,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
return nberrors.FormatErrorOrNil(merr)
}
// AllowNetbird allows netbird interface traffic
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
}
_, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0},

View File

@@ -47,8 +47,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)

View File

@@ -9,10 +9,9 @@ import (
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -23,10 +22,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex

View File

@@ -40,7 +40,6 @@ func getTableName() string {
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
IsUserspaceBind() bool
}
// Manager of iptables firewall
@@ -106,10 +105,9 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -205,12 +203,10 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
return m.router.RemoveNatRule(pair)
}
// AllowNetbird allows netbird interface traffic
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
}
m.mutex.Lock()
defer m.mutex.Unlock()

View File

@@ -52,8 +52,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
// just check on the local interface

View File

@@ -8,10 +8,9 @@ import (
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -22,10 +21,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}

View File

@@ -19,6 +19,9 @@ import (
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
@@ -135,6 +138,7 @@ func TestDefaultManager(t *testing.T) {
func TestDefaultManagerStateless(t *testing.T) {
// stateless currently only in userspace, so we have to disable kernel
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
t.Setenv("NB_DISABLE_CONNTRACK", "true")
networkMap := &mgmProto.NetworkMap{
@@ -194,6 +198,7 @@ func TestDefaultManagerStateless(t *testing.T) {
// This tests the full ACL manager -> uspfilter integration.
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
@@ -258,6 +263,7 @@ func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
// up when they're removed from the network map in a subsequent update.
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
@@ -339,6 +345,7 @@ func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
// one added without leaking.
func TestRuleUpdateChangingAction(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()

View File

@@ -543,6 +543,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
ServerRDPAllowed: config.ServerRDPAllowed != nil && *config.ServerRDPAllowed,
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,

View File

@@ -117,6 +117,7 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
ServerRDPAllowed bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -197,6 +198,7 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
rdpServer rdpServer
statusRecorder *peer.Status
@@ -1036,6 +1038,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
}
if err := e.updateRDP(); err != nil {
log.Warnf("failed handling RDP server setup: %v", err)
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
@@ -1323,6 +1329,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
// Reuse SSH ACL for RDP authorization
e.updateRDPServerAuth(networkMap.GetSshAuth())
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store

View File

@@ -0,0 +1,191 @@
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
// rdpServer is the subset of the sideband RDP auth server's methods that the
// engine uses: lifecycle control, access to the pending-session store, and
// updates to the authorization config.
type rdpServer interface {
	// Start begins listening on the given address within the provided context.
	Start(ctx context.Context, addr netip.AddrPort) error
	// Stop shuts the server down.
	Stop() error
	// GetPendingStore returns the store of pending RDP sessions.
	GetPendingStore() *rdpserver.PendingStore
	// UpdateRDPAuth replaces the authorization configuration.
	UpdateRDPAuth(config *sshauth.Config)
}
// setupRDPPortRedirection installs a DNAT rule on the NetBird interface
// address that redirects the public RDP auth port to the internal listener
// port. It is a no-op when the firewall or the WireGuard interface is not
// initialized.
func (e *Engine) setupRDPPortRedirection() error {
	if e.firewall == nil || e.wgInterface == nil {
		return nil
	}

	addr := e.wgInterface.Address().IP
	if !addr.IsValid() {
		return errors.New("invalid local NetBird address")
	}

	err := e.firewall.AddInboundDNAT(addr, firewallManager.ProtocolTCP, rdpserver.DefaultRDPAuthPort, rdpserver.InternalRDPAuthPort)
	if err != nil {
		return fmt.Errorf("add RDP auth port redirection: %w", err)
	}

	log.Infof("RDP auth port redirection enabled: %s:%d -> %s:%d",
		addr, rdpserver.DefaultRDPAuthPort, addr, rdpserver.InternalRDPAuthPort)
	return nil
}
// cleanupRDPPortRedirection removes the DNAT rule installed by
// setupRDPPortRedirection. It is a no-op when the firewall or the WireGuard
// interface is not initialized.
func (e *Engine) cleanupRDPPortRedirection() error {
	if e.firewall == nil || e.wgInterface == nil {
		return nil
	}

	addr := e.wgInterface.Address().IP
	if !addr.IsValid() {
		return errors.New("invalid local NetBird address")
	}

	err := e.firewall.RemoveInboundDNAT(addr, firewallManager.ProtocolTCP, rdpserver.DefaultRDPAuthPort, rdpserver.InternalRDPAuthPort)
	if err != nil {
		return fmt.Errorf("remove RDP auth port redirection: %w", err)
	}

	log.Debugf("RDP auth port redirection removed: %s:%d -> %s:%d",
		addr, rdpserver.DefaultRDPAuthPort, addr, rdpserver.InternalRDPAuthPort)
	return nil
}
// updateRDP reconciles the RDP auth server with the current configuration:
// the server is stopped when the feature flag is off or inbound connections
// are blocked, started when the flag is on and no server is running, and
// left alone otherwise.
func (e *Engine) updateRDP() error {
	switch {
	case !e.config.ServerRDPAllowed:
		if e.rdpServer != nil {
			log.Info("RDP passthrough disabled, stopping RDP auth server")
		}
		return e.stopRDPServer()
	case e.config.BlockInbound:
		log.Info("RDP server is disabled because inbound connections are blocked")
		return e.stopRDPServer()
	case e.rdpServer != nil:
		log.Debug("RDP auth server is already running")
		return nil
	default:
		return e.startRDPServer()
	}
}
// startRDPServer starts the sideband RDP auth server on the WireGuard
// interface address, then (in order) registers the listener port with
// netstack when running in netstack mode, installs the DNAT redirection from
// the public auth port to the internal one, and registers the credential
// provider DLL. Redirection and DLL-registration failures are logged as
// warnings but are not fatal; only a failure to start the listener is.
func (e *Engine) startRDPServer() error {
	if e.wgInterface == nil {
		return errors.New("wg interface not initialized")
	}
	wgAddr := e.wgInterface.Address()
	cfg := &rdpserver.Config{
		NetworkAddr: wgAddr.Network,
	}
	server := rdpserver.New(cfg)
	// Listen on the internal port on the local NetBird IP; external clients
	// reach it via the DNAT rule installed below.
	netbirdIP := wgAddr.IP
	listenAddr := netip.AddrPortFrom(netbirdIP, rdpserver.InternalRDPAuthPort)
	if err := server.Start(e.ctx, listenAddr); err != nil {
		return fmt.Errorf("start RDP auth server: %w", err)
	}
	e.rdpServer = server
	// In netstack mode the service must be registered with the firewall so
	// inbound traffic reaches the userspace listener; the optional interface
	// assertion keeps this a no-op for firewalls without netstack support.
	if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
		if registrar, ok := e.firewall.(interface {
			RegisterNetstackService(protocol nftypes.Protocol, port uint16)
		}); ok {
			registrar.RegisterNetstackService(nftypes.TCP, rdpserver.InternalRDPAuthPort)
			log.Debugf("registered RDP auth service with netstack for TCP:%d", rdpserver.InternalRDPAuthPort)
		}
	}
	if err := e.setupRDPPortRedirection(); err != nil {
		log.Warnf("failed to setup RDP auth port redirection: %v", err)
	}
	// Register the credential provider DLL dynamically (Windows only)
	if err := rdpserver.RegisterCredentialProvider(); err != nil {
		log.Warnf("failed to register RDP credential provider (passwordless RDP will not work): %v", err)
	}
	log.Info("RDP passthrough enabled")
	return nil
}
// stopRDPServer tears down the RDP auth server: it removes the DNAT port
// redirection, unregisters the netstack service and the credential provider
// DLL, stops the listener, and clears e.rdpServer. Cleanup failures other
// than the final Stop are logged but not returned.
//
// Fix: guard e.wgInterface before calling GetNet() — unlike startRDPServer,
// the original did not nil-check the interface, so stopping during or after
// interface teardown would panic.
func (e *Engine) stopRDPServer() error {
	if e.rdpServer == nil {
		return nil
	}
	if err := e.cleanupRDPPortRedirection(); err != nil {
		log.Warnf("failed to cleanup RDP auth port redirection: %v", err)
	}
	if e.wgInterface != nil {
		if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
			if registrar, ok := e.firewall.(interface {
				UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
			}); ok {
				registrar.UnregisterNetstackService(nftypes.TCP, rdpserver.InternalRDPAuthPort)
				log.Debugf("unregistered RDP auth service from netstack for TCP:%d", rdpserver.InternalRDPAuthPort)
			}
		}
	}
	// Unregister the credential provider DLL (Windows only)
	if err := rdpserver.UnregisterCredentialProvider(); err != nil {
		log.Warnf("failed to unregister RDP credential provider: %v", err)
	}
	log.Info("stopping RDP auth server")
	err := e.rdpServer.Stop()
	e.rdpServer = nil
	if err != nil {
		return fmt.Errorf("stop: %w", err)
	}
	return nil
}
// updateRDPServerAuth reuses the SSH authorization config for RDP access
// control, so the same user/machine-user mappings that govern SSH access
// also govern RDP. A malformed user hash aborts the whole update.
func (e *Engine) updateRDPServerAuth(sshAuth *mgmProto.SSHAuth) {
	if sshAuth == nil || e.rdpServer == nil {
		return
	}

	rawHashes := sshAuth.GetAuthorizedUsers()
	users := make([]sshuserhash.UserIDHash, len(rawHashes))
	for i, raw := range rawHashes {
		if len(raw) != 16 {
			log.Warnf("invalid hash length %d, expected 16 - skipping RDP server auth update", len(raw))
			return
		}
		users[i] = sshuserhash.UserIDHash(raw)
	}

	protoMachineUsers := sshAuth.GetMachineUsers()
	machineUsers := make(map[string][]uint32, len(protoMachineUsers))
	for osUser, indexes := range protoMachineUsers {
		machineUsers[osUser] = indexes.GetIndexes()
	}

	e.rdpServer.UpdateRDPAuth(&sshauth.Config{
		UserIDClaim:     sshAuth.GetUserIDClaim(),
		AuthorizedUsers: users,
		MachineUsers:    machineUsers,
	})
}

View File

@@ -64,6 +64,7 @@ type ConfigInput struct {
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
ServerRDPAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -114,6 +115,7 @@ type Config struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
ServerRDPAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -415,6 +417,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerRDPAllowed != nil {
if config.ServerRDPAllowed == nil || *input.ServerRDPAllowed != *config.ServerRDPAllowed {
if *input.ServerRDPAllowed {
log.Infof("enabling RDP passthrough")
} else {
log.Infof("disabling RDP passthrough")
}
config.ServerRDPAllowed = input.ServerRDPAllowed
updated = true
}
} else if config.ServerRDPAllowed == nil {
config.ServerRDPAllowed = util.False()
updated = true
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")

View File

@@ -472,6 +472,7 @@ type LoginRequest struct {
EnableSSHRemotePortForwarding *bool `protobuf:"varint,37,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth *bool `protobuf:"varint,38,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
SshJWTCacheTTL *int32 `protobuf:"varint,39,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
ServerRDPAllowed *bool `protobuf:"varint,40,opt,name=serverRDPAllowed,proto3,oneof" json:"serverRDPAllowed,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -780,6 +781,13 @@ func (x *LoginRequest) GetSshJWTCacheTTL() int32 {
return 0
}
func (x *LoginRequest) GetServerRDPAllowed() bool {
if x != nil && x.ServerRDPAllowed != nil {
return *x.ServerRDPAllowed
}
return false
}
type LoginResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
@@ -1312,6 +1320,7 @@ type GetConfigResponse struct {
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"`
ServerRDPAllowed bool `protobuf:"varint,27,opt,name=serverRDPAllowed,proto3" json:"serverRDPAllowed,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -1528,6 +1537,13 @@ func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 {
return 0
}
func (x *GetConfigResponse) GetServerRDPAllowed() bool {
if x != nil {
return x.ServerRDPAllowed
}
return false
}
// PeerState contains the latest state of a peer
type PeerState struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -4139,6 +4155,7 @@ type SetConfigRequest struct {
EnableSSHRemotePortForwarding *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
SshJWTCacheTTL *int32 `protobuf:"varint,34,opt,name=sshJWTCacheTTL,proto3,oneof" json:"sshJWTCacheTTL,omitempty"`
ServerRDPAllowed *bool `protobuf:"varint,35,opt,name=serverRDPAllowed,proto3,oneof" json:"serverRDPAllowed,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4411,6 +4428,13 @@ func (x *SetConfigRequest) GetSshJWTCacheTTL() int32 {
return 0
}
// GetServerRDPAllowed reports whether the optional serverRDPAllowed field is
// set on the config request, defaulting to false when the message or the
// field is absent.
func (x *SetConfigRequest) GetServerRDPAllowed() bool {
	if x == nil || x.ServerRDPAllowed == nil {
		return false
	}
	return *x.ServerRDPAllowed
}
type SetConfigResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields

View File

@@ -209,6 +209,8 @@ message LoginRequest {
optional bool enableSSHRemotePortForwarding = 37;
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
optional bool serverRDPAllowed = 40;
}
message LoginResponse {
@@ -316,6 +318,8 @@ message GetConfigResponse {
bool disableSSHAuth = 25;
int32 sshJWTCacheTTL = 26;
bool serverRDPAllowed = 27;
}
// PeerState contains the latest state of a peer
@@ -677,6 +681,8 @@ message SetConfigRequest {
optional bool enableSSHRemotePortForwarding = 32;
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
optional bool serverRDPAllowed = 35;
}
message SetConfigResponse{}

View File

@@ -0,0 +1,88 @@
package client
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"time"
rdpserver "github.com/netbirdio/netbird/client/rdp/server"
)
const (
	// DefaultTimeout is the default timeout for sideband auth requests.
	// It bounds the full dial+write+read exchange in Client.RequestAuth.
	DefaultTimeout = 30 * time.Second
	// maxResponseSize is the maximum size of an auth response in bytes;
	// RequestAuth reads through an io.LimitReader with this cap.
	maxResponseSize = 64 * 1024
)
// Client connects to a target peer's RDP sideband auth server to request access.
type Client struct {
	// Timeout bounds the whole RequestAuth exchange (dial, write, read).
	// Values <= 0 fall back to DefaultTimeout at call time.
	Timeout time.Duration
}
// New returns a sideband RDP auth client configured with DefaultTimeout.
func New() *Client {
	c := &Client{}
	c.Timeout = DefaultTimeout
	return c
}
// RequestAuth sends an authorization request to the target peer's sideband server
// and returns the response. The addr should be in "host:port" format.
//
// The whole exchange is bounded by c.Timeout (DefaultTimeout when unset): the
// request is JSON-encoded, written, the write side is half-closed so the
// server sees EOF, and the JSON response is read until EOF, capped at
// maxResponseSize.
func (c *Client) RequestAuth(ctx context.Context, addr string, req *rdpserver.AuthRequest) (*rdpserver.AuthResponse, error) {
	// Marshal before dialing so an invalid request never opens a connection.
	reqData, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("marshal auth request: %w", err)
	}

	timeout := c.Timeout
	if timeout <= 0 {
		timeout = DefaultTimeout
	}
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	dialer := &net.Dialer{}
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("connect to RDP auth server at %s: %w", addr, err)
	}
	defer func() { _ = conn.Close() }()

	// Propagate the context deadline to the connection so I/O cannot
	// outlive the request timeout.
	if deadline, ok := ctx.Deadline(); ok {
		if err := conn.SetDeadline(deadline); err != nil {
			return nil, fmt.Errorf("set connection deadline: %w", err)
		}
	}

	if _, err := conn.Write(reqData); err != nil {
		return nil, fmt.Errorf("send auth request: %w", err)
	}

	// Signal we're done writing so the server can read the full request.
	if tcpConn, ok := conn.(*net.TCPConn); ok {
		if err := tcpConn.CloseWrite(); err != nil {
			return nil, fmt.Errorf("close write: %w", err)
		}
	}

	respData, err := io.ReadAll(io.LimitReader(conn, maxResponseSize))
	if err != nil {
		return nil, fmt.Errorf("read auth response: %w", err)
	}

	var resp rdpserver.AuthResponse
	if err := json.Unmarshal(respData, &resp); err != nil {
		return nil, fmt.Errorf("unmarshal auth response: %w", err)
	}
	return &resp, nil
}

View File

@@ -0,0 +1,31 @@
[package]
name = "netbird-credprov"
version = "0.1.0"
edition = "2021"
description = "NetBird RDP Credential Provider for Windows"
license = "BSD-3-Clause"
[lib]
crate-type = ["cdylib"]
[dependencies]
windows = { version = "0.58", features = [
"implement",
"Win32_Foundation",
"Win32_System_Com",
"Win32_UI_Shell",
"Win32_Security",
"Win32_Security_Authentication_Identity",
"Win32_Security_Credentials",
"Win32_System_RemoteDesktop",
"Win32_System_Threading",
] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
uuid = { version = "1", features = ["v4"] }
log = "0.4"
[profile.release]
opt-level = "s"
lto = true
strip = true

View File

@@ -0,0 +1,210 @@
//! ICredentialProviderCredential implementation.
//!
//! Represents a single "NetBird Login" credential tile on the Windows login screen.
//! When selected, it queries the local NetBird agent for pending RDP sessions and
//! performs S4U logon to authenticate the user without a password.
use crate::named_pipe_client::{NamedPipeClient, PipeResponse};
use crate::s4u;
use std::sync::Mutex;
use windows::core::*;
use windows::Win32::Foundation::*;
use windows::Win32::Security::Credentials::*;
use windows::Win32::UI::Shell::*;
/// NetBird credential tile that appears on the Windows login screen.
#[implement(ICredentialProviderCredential)]
pub struct NetBirdCredential {
    /// The pending session information from the NetBird agent.
    /// Cloned by GetSerialization when the tile is submitted.
    session: Mutex<Option<PipeResponse>>,
    /// The remote IP address of the connecting peer.
    remote_ip: Mutex<String>,
}

impl NetBirdCredential {
    /// Wraps a pending session (and the peer IP it belongs to) in a tile.
    pub fn new(remote_ip: String, session: PipeResponse) -> Self {
        Self {
            session: Mutex::new(Some(session)),
            remote_ip: Mutex::new(remote_ip),
        }
    }
}
impl ICredentialProviderCredential_Impl for NetBirdCredential_Impl {
    // Event-sink registration from LogonUI. This tile never raises
    // field-change events, so both calls are no-ops.
    fn Advise(&self, _pcpce: Option<&ICredentialProviderCredentialEvents>) -> Result<()> {
        Ok(())
    }
    fn UnAdvise(&self) -> Result<()> {
        Ok(())
    }
    // Called when the tile is selected; requests immediate auto-logon by
    // setting *pbautologon to TRUE.
    fn SetSelected(&self, _pbautologon: *mut BOOL) -> Result<()> {
        // Auto-logon when this credential is selected
        unsafe {
            if !_pbautologon.is_null() {
                *_pbautologon = TRUE;
            }
        }
        Ok(())
    }
    fn SetDeselected(&self) -> Result<()> {
        Ok(())
    }
    fn GetFieldState(
        &self,
        _dwfieldid: u32,
        _pcpfs: *mut CREDENTIAL_PROVIDER_FIELD_STATE,
        _pcpfis: *mut CREDENTIAL_PROVIDER_FIELD_INTERACTIVE_STATE,
    ) -> Result<()> {
        // We have a single display-only field showing "NetBird Login"
        unsafe {
            if !_pcpfs.is_null() {
                *_pcpfs = CPFS_DISPLAY_IN_SELECTED_TILE;
            }
            if !_pcpfis.is_null() {
                *_pcpfis = CPFIS_NONE;
            }
        }
        Ok(())
    }
    // Returns the tile's label text. The buffer is CoTaskMemAlloc'd because
    // the caller (LogonUI) frees it with CoTaskMemFree.
    fn GetStringValue(&self, _dwfieldid: u32) -> Result<PWSTR> {
        let session = self.session.lock().unwrap();
        let text = if let Some(ref s) = *session {
            format!("NetBird: Logging in as {}", s.os_user)
        } else {
            "NetBird Login".to_string()
        };
        // UTF-16 with a trailing NUL, as a PWSTR requires.
        let wide: Vec<u16> = text.encode_utf16().chain(std::iter::once(0)).collect();
        let ptr = unsafe {
            let mem = windows::Win32::System::Com::CoTaskMemAlloc(wide.len() * 2) as *mut u16;
            if mem.is_null() {
                return Err(E_OUTOFMEMORY.into());
            }
            std::ptr::copy_nonoverlapping(wide.as_ptr(), mem, wide.len());
            PWSTR(mem)
        };
        Ok(ptr)
    }
    // The remaining field accessors/mutators are unimplemented: the tile has
    // no bitmap, checkbox, combo box, submit button, or editable fields.
    fn GetBitmapValue(&self, _dwfieldid: u32) -> Result<HBITMAP> {
        Err(E_NOTIMPL.into())
    }
    fn GetCheckboxValue(&self, _dwfieldid: u32, _pbchecked: *mut BOOL, _ppszlabel: *mut PWSTR) -> Result<()> {
        Err(E_NOTIMPL.into())
    }
    fn GetSubmitButtonValue(&self, _dwfieldid: u32, _pdwadjacentto: *mut u32) -> Result<()> {
        Err(E_NOTIMPL.into())
    }
    fn GetComboBoxValueCount(&self, _dwfieldid: u32, _pcitems: *mut u32, _pdwselecteditem: *mut u32) -> Result<()> {
        Err(E_NOTIMPL.into())
    }
    fn GetComboBoxValueAt(&self, _dwfieldid: u32, _dwitem: u32) -> Result<PWSTR> {
        Err(E_NOTIMPL.into())
    }
    fn SetStringValue(&self, _dwfieldid: u32, _psz: &PCWSTR) -> Result<()> {
        Err(E_NOTIMPL.into())
    }
    fn SetCheckboxValue(&self, _dwfieldid: u32, _bchecked: BOOL) -> Result<()> {
        Err(E_NOTIMPL.into())
    }
    fn SetComboBoxSelectedValue(&self, _dwfieldid: u32, _dwselecteditem: u32) -> Result<()> {
        Err(E_NOTIMPL.into())
    }
    fn CommandLinkClicked(&self, _dwfieldid: u32) -> Result<()> {
        Err(E_NOTIMPL.into())
    }
    // Called when the user submits the tile. Consumes the pending session at
    // the agent (replay protection), performs the S4U logon, and reports the
    // outcome through *_pcpgsr. Errors are reported via the out-param, not as
    // a failed HRESULT, so LogonUI falls back gracefully.
    // NOTE(review): _pcpgsr is dereferenced without a null check here, unlike
    // SetSelected/GetFieldState — LogonUI passes valid pointers, but the
    // inconsistency is worth confirming.
    fn GetSerialization(
        &self,
        _pcpgsr: *mut CREDENTIAL_PROVIDER_GET_SERIALIZATION_RESPONSE,
        _pcpcs: *mut CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION,
        _ppszoptionalstatustext: *mut PWSTR,
        _pcpsioptionalstatusicon: *mut CREDENTIAL_PROVIDER_STATUS_ICON,
    ) -> Result<()> {
        let session = self.session.lock().unwrap();
        let session_info = match &*session {
            Some(s) => s.clone(),
            None => {
                // No pending session: tell LogonUI we produced nothing.
                unsafe {
                    *_pcpgsr = CPGSR_NO_CREDENTIAL_NOT_FINISHED;
                }
                return Ok(());
            }
        };
        // Consume the session with the agent
        if let Err(e) = NamedPipeClient::consume_session(&session_info.session_id) {
            log::error!("Failed to consume RDP session: {}", e);
            unsafe {
                *_pcpgsr = CPGSR_NO_CREDENTIAL_FINISHED;
            }
            return Ok(());
        }
        // Perform S4U logon
        let username = &session_info.os_user;
        // An empty domain means the local machine account database (".").
        let domain = if session_info.domain.is_empty() {
            "."
        } else {
            &session_info.domain
        };
        match s4u::generate_s4u_token(username, domain) {
            Ok(_token) => {
                // In a full implementation, we would serialize the token into
                // CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION format
                // (KerbInteractiveLogon or MsV1_0InteractiveLogon structure).
                //
                // For the POC, we signal success. The actual serialization requires
                // building the proper KERB_INTERACTIVE_LOGON or MSV1_0_INTERACTIVE_LOGON
                // structure with the token handle, which is complex.
                //
                // TODO: Build proper credential serialization from S4U token
                log::info!(
                    "S4U logon successful for {}\\{}, session {}",
                    domain,
                    username,
                    session_info.session_id
                );
                unsafe {
                    *_pcpgsr = CPGSR_RETURN_CREDENTIAL_FINISHED;
                    // Note: In production, pcpcs would be filled with the serialized credentials
                }
                Ok(())
            }
            Err(e) => {
                log::error!("S4U logon failed for {}\\{}: {}", domain, username, e);
                unsafe {
                    *_pcpgsr = CPGSR_NO_CREDENTIAL_FINISHED;
                }
                Ok(())
            }
        }
    }
    // Final logon status callback from LogonUI; nothing to surface.
    fn ReportResult(
        &self,
        _ntstatus: NTSTATUS,
        _ntssubstatus: NTSTATUS,
        _ppszoptionalstatustext: *mut PWSTR,
        _pcpsioptionalstatusicon: *mut CREDENTIAL_PROVIDER_STATUS_ICON,
    ) -> Result<()> {
        Ok(())
    }
}

View File

@@ -0,0 +1,11 @@
use windows::core::GUID;

/// CLSID for the NetBird RDP Credential Provider.
/// Generated UUID: {7B3A8E5F-1C4D-4F8A-B2E6-9D0F3A7C5E1B}
/// Must match the registry entries written by register_credential_provider.
pub const CLSID_NETBIRD_CREDENTIAL_PROVIDER: GUID = GUID::from_u128(
    0x7B3A8E5F_1C4D_4F8A_B2E6_9D0F3A7C5E1B,
);

/// Registry path (under HKLM) where Windows enumerates credential providers.
pub const CREDENTIAL_PROVIDER_REGISTRY_PATH: &str =
    r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers";

View File

@@ -0,0 +1,309 @@
//! NetBird RDP Credential Provider for Windows.
//!
//! This DLL is a Windows Credential Provider that enables passwordless RDP access
//! to machines running the NetBird agent. It is loaded by Windows' LogonUI.exe
//! via COM when the login screen is displayed.
//!
//! ## How it works
//!
//! 1. The DLL is registered as a Credential Provider in the Windows registry
//! 2. When an RDP session begins, LogonUI loads the DLL
//! 3. The DLL queries the local NetBird agent via named pipe for pending sessions
//! 4. If a pending session exists for the connecting peer, the DLL:
//! - Shows a "NetBird Login" credential tile
//! - Performs S4U logon to create a Windows token without a password
//! - Returns the token to LogonUI for session creation
mod credential;
mod guid;
mod named_pipe_client;
mod provider;
mod s4u;
use guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
use provider::NetBirdCredentialProvider;
use std::sync::atomic::{AtomicU32, Ordering};
use windows::core::*;
use windows::Win32::Foundation::*;
use windows::Win32::System::Com::*;
/// DLL reference count for COM lifecycle management (LockServer locks).
static DLL_REF_COUNT: AtomicU32 = AtomicU32::new(0);
/// DLL module handle, captured in DllMain at DLL_PROCESS_ATTACH and read by
/// register_credential_provider to resolve the DLL's on-disk path.
/// NOTE(review): `static mut` access is unsynchronized; this is only safe
/// because DllMain writes it once before any other export runs — confirm.
static mut DLL_MODULE: HMODULE = HMODULE(std::ptr::null_mut());
/// COM class factory for creating NetBirdCredentialProvider instances.
#[implement(IClassFactory)]
struct NetBirdClassFactory;
impl IClassFactory_Impl for NetBirdClassFactory_Impl {
    /// Creates a NetBirdCredentialProvider and returns the interface
    /// requested in `riid` through `ppvobject`.
    fn CreateInstance(
        &self,
        _punkouter: Option<&IUnknown>,
        riid: *const GUID,
        ppvobject: *mut *mut std::ffi::c_void,
    ) -> Result<()> {
        // Defensively null the out-pointer so every failure path leaves it defined.
        unsafe {
            if !ppvobject.is_null() {
                *ppvobject = std::ptr::null_mut();
            }
        }
        // COM aggregation is not supported.
        if _punkouter.is_some() {
            return Err(CLASS_E_NOAGGREGATION.into());
        }
        let provider = NetBirdCredentialProvider::new();
        let unknown: IUnknown = provider.into();
        unsafe {
            // QueryInterface for the requested IID; writes into ppvobject.
            unknown.query(riid, ppvobject).ok()
        }
    }
    /// Tracks LockServer calls in DLL_REF_COUNT so DllCanUnloadNow can
    /// report whether the DLL may be unloaded.
    fn LockServer(&self, flock: BOOL) -> Result<()> {
        if flock.as_bool() {
            DLL_REF_COUNT.fetch_add(1, Ordering::SeqCst);
        } else {
            DLL_REF_COUNT.fetch_sub(1, Ordering::SeqCst);
        }
        Ok(())
    }
}
/// DLL entry point. Captures the module handle on process attach so the
/// registration code can later resolve the DLL's on-disk path.
#[no_mangle]
extern "system" fn DllMain(hinstance: HMODULE, reason: u32, _reserved: *mut std::ffi::c_void) -> BOOL {
    // fdwReason value passed when the DLL is first mapped into the process.
    const DLL_PROCESS_ATTACH: u32 = 1;
    if reason == DLL_PROCESS_ATTACH {
        unsafe {
            DLL_MODULE = hinstance;
        }
    }
    TRUE
}
/// COM entry point: returns a class factory for the requested CLSID.
/// Fails with CLASS_E_CLASSNOTAVAILABLE for any CLSID other than ours.
#[no_mangle]
extern "system" fn DllGetClassObject(
    rclsid: *const GUID,
    riid: *const GUID,
    ppv: *mut *mut std::ffi::c_void,
) -> HRESULT {
    unsafe {
        if ppv.is_null() {
            return E_POINTER;
        }
        // Null the out-pointer up front so failure paths leave it defined.
        *ppv = std::ptr::null_mut();
        if *rclsid != CLSID_NETBIRD_CREDENTIAL_PROVIDER {
            return CLASS_E_CLASSNOTAVAILABLE;
        }
        let factory = NetBirdClassFactory;
        let unknown: IUnknown = factory.into();
        // QueryInterface for the caller-requested interface (IClassFactory).
        match unknown.query(riid, ppv) {
            Ok(()) => S_OK,
            Err(e) => e.code(),
        }
    }
}
/// COM entry point: indicates whether the DLL can be unloaded.
/// S_OK when no outstanding LockServer locks remain, otherwise S_FALSE.
#[no_mangle]
extern "system" fn DllCanUnloadNow() -> HRESULT {
    match DLL_REF_COUNT.load(Ordering::SeqCst) {
        0 => S_OK,
        _ => S_FALSE,
    }
}
/// Self-registration: called by regsvr32 to register the credential provider.
#[no_mangle]
extern "system" fn DllRegisterServer() -> HRESULT {
    if register_credential_provider(true).is_ok() {
        S_OK
    } else {
        E_FAIL
    }
}
/// Self-unregistration: called by regsvr32 /u to unregister the credential provider.
#[no_mangle]
extern "system" fn DllUnregisterServer() -> HRESULT {
    if register_credential_provider(false).is_ok() {
        S_OK
    } else {
        E_FAIL
    }
}
/// Creates (register=true) or deletes (register=false) the registry entries
/// that make this DLL a Windows credential provider:
///  - HKLM\...\Authentication\Credential Providers\{CLSID}  (provider listing)
///  - HKCR\CLSID\{CLSID}\InprocServer32                     (COM server: DLL path + threading model)
fn register_credential_provider(register: bool) -> std::result::Result<(), Box<dyn std::error::Error>> {
    use windows::Win32::System::Registry::*;
    // Format the CLSID as a registry-style GUID string: {XXXXXXXX-XXXX-...}.
    let clsid_str = format!("{{{:08X}-{:04X}-{:04X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}}}",
        CLSID_NETBIRD_CREDENTIAL_PROVIDER.data1,
        CLSID_NETBIRD_CREDENTIAL_PROVIDER.data2,
        CLSID_NETBIRD_CREDENTIAL_PROVIDER.data3,
        CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[0],
        CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[1],
        CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[2],
        CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[3],
        CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[4],
        CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[5],
        CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[6],
        CLSID_NETBIRD_CREDENTIAL_PROVIDER.data4[7],
    );
    if register {
        // Register under Credential Providers
        let cp_key_path = format!(
            r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
            clsid_str
        );
        // All registry string payloads are NUL-terminated UTF-16 (REG_SZ).
        let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
        let mut hkey = HKEY::default();
        unsafe {
            let result = RegCreateKeyExW(
                HKEY_LOCAL_MACHINE,
                PCWSTR(cp_key_wide.as_ptr()),
                0,
                PCWSTR::null(),
                REG_OPTION_NON_VOLATILE,
                KEY_WRITE,
                None,
                &mut hkey,
                None,
            );
            if result.is_err() {
                return Err("Failed to create credential provider registry key".into());
            }
            // Friendly display name, stored as the key's default value
            // (PCWSTR::null() selects the default value).
            let value: Vec<u16> = "NetBird RDP Credential Provider"
                .encode_utf16()
                .chain(std::iter::once(0))
                .collect();
            let _ = RegSetValueExW(
                hkey,
                PCWSTR::null(),
                0,
                REG_SZ,
                Some(std::slice::from_raw_parts(
                    value.as_ptr() as *const u8,
                    value.len() * 2,
                )),
            );
            let _ = RegCloseKey(hkey);
        }
        // Register CLSID in CLSID hive
        let clsid_key_path = format!(r"CLSID\{}", clsid_str);
        let clsid_key_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
        unsafe {
            let result = RegCreateKeyExW(
                HKEY_CLASSES_ROOT,
                PCWSTR(clsid_key_wide.as_ptr()),
                0,
                PCWSTR::null(),
                REG_OPTION_NON_VOLATILE,
                KEY_WRITE,
                None,
                &mut hkey,
                None,
            );
            if result.is_err() {
                return Err("Failed to create CLSID registry key".into());
            }
            let _ = RegCloseKey(hkey);
            // InprocServer32 subkey
            let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
            let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
            let result = RegCreateKeyExW(
                HKEY_CLASSES_ROOT,
                PCWSTR(inproc_wide.as_ptr()),
                0,
                PCWSTR::null(),
                REG_OPTION_NON_VOLATILE,
                KEY_WRITE,
                None,
                &mut hkey,
                None,
            );
            if result.is_err() {
                return Err("Failed to create InprocServer32 registry key".into());
            }
            // Set DLL path (default value of InprocServer32), resolved from
            // the module handle captured in DllMain.
            // NOTE(review): a path longer than 260 UTF-16 units is silently
            // truncated by GetModuleFileNameW with this fixed buffer.
            let mut dll_path = [0u16; 260];
            let len = windows::Win32::System::LibraryLoader::GetModuleFileNameW(
                DLL_MODULE,
                &mut dll_path,
            );
            if len > 0 {
                // (len + 1) * 2: include the trailing NUL, in bytes.
                let _ = RegSetValueExW(
                    hkey,
                    PCWSTR::null(),
                    0,
                    REG_SZ,
                    Some(std::slice::from_raw_parts(
                        dll_path.as_ptr() as *const u8,
                        (len as usize + 1) * 2,
                    )),
                );
            }
            // Set threading model
            let threading: Vec<u16> = "Apartment"
                .encode_utf16()
                .chain(std::iter::once(0))
                .collect();
            let threading_name: Vec<u16> = "ThreadingModel"
                .encode_utf16()
                .chain(std::iter::once(0))
                .collect();
            let _ = RegSetValueExW(
                hkey,
                PCWSTR(threading_name.as_ptr()),
                0,
                REG_SZ,
                Some(std::slice::from_raw_parts(
                    threading.as_ptr() as *const u8,
                    threading.len() * 2,
                )),
            );
            let _ = RegCloseKey(hkey);
        }
    } else {
        // Unregister: deletions are best-effort (missing keys are ignored).
        let cp_key_path = format!(
            r"SOFTWARE\Microsoft\Windows\CurrentVersion\Authentication\Credential Providers\{}",
            clsid_str
        );
        let cp_key_wide: Vec<u16> = cp_key_path.encode_utf16().chain(std::iter::once(0)).collect();
        unsafe {
            let _ = RegDeleteKeyW(HKEY_LOCAL_MACHINE, PCWSTR(cp_key_wide.as_ptr()));
        }
        // Delete the child key (InprocServer32) before its parent, since
        // RegDeleteKeyW cannot remove keys that still have subkeys.
        let inproc_path = format!(r"CLSID\{}\InprocServer32", clsid_str);
        let inproc_wide: Vec<u16> = inproc_path.encode_utf16().chain(std::iter::once(0)).collect();
        let clsid_key_path = format!(r"CLSID\{}", clsid_str);
        let clsid_wide: Vec<u16> = clsid_key_path.encode_utf16().chain(std::iter::once(0)).collect();
        unsafe {
            let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(inproc_wide.as_ptr()));
            let _ = RegDeleteKeyW(HKEY_CLASSES_ROOT, PCWSTR(clsid_wide.as_ptr()));
        }
    }
    Ok(())
}

View File

@@ -0,0 +1,135 @@
use serde::{Deserialize, Serialize};
use std::io::{Read, Write};
use std::time::Duration;
/// Named pipe path for communicating with the NetBird agent.
const PIPE_NAME: &str = r"\\.\pipe\netbird-rdp-auth";
/// Maximum response size from the agent (single-read buffer size).
const MAX_RESPONSE_SIZE: usize = 4096;
/// Timeout for named pipe operations.
/// NOTE(review): declared but never applied to any pipe I/O in this file —
/// either wire it into open/read/write or remove it.
const PIPE_TIMEOUT: Duration = Duration::from_secs(5);
/// Request sent to the NetBird agent via named pipe.
#[derive(Serialize)]
pub struct PipeRequest {
    /// Either "query_pending" or "consume".
    pub action: String,
    /// Set for "query_pending": the connecting peer's IP.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_ip: Option<String>,
    /// Set for "consume": the session to mark as used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
}
/// Response received from the NetBird agent via named pipe.
#[derive(Deserialize, Debug, Clone)]
pub struct PipeResponse {
    /// True when a matching pending session exists; the remaining fields are
    /// only meaningful in that case (they default to empty otherwise).
    pub found: bool,
    #[serde(default)]
    pub session_id: String,
    #[serde(default)]
    pub os_user: String,
    #[serde(default)]
    pub domain: String,
}
/// Client for communicating with the NetBird agent's named pipe server.
pub struct NamedPipeClient;

impl NamedPipeClient {
    /// Query the NetBird agent for a pending RDP session matching the given remote IP.
    pub fn query_pending(remote_ip: &str) -> Result<PipeResponse, PipeError> {
        let request = PipeRequest {
            action: "query_pending".to_string(),
            remote_ip: Some(remote_ip.to_string()),
            session_id: None,
        };
        Self::send_request(&request)
    }

    /// Tell the NetBird agent to consume (mark as used) a pending session.
    pub fn consume_session(session_id: &str) -> Result<PipeResponse, PipeError> {
        let request = PipeRequest {
            action: "consume".to_string(),
            remote_ip: None,
            session_id: Some(session_id.to_string()),
        };
        Self::send_request(&request)
    }

    /// Serialize `request` as JSON, send it over the agent pipe, and decode
    /// the single JSON response message.
    fn send_request(request: &PipeRequest) -> Result<PipeResponse, PipeError> {
        let request_data =
            serde_json::to_vec(request).map_err(|e| PipeError::Serialization(e.to_string()))?;
        // Open named pipe (CreateFile in Windows)
        let mut pipe = Self::open_pipe()?;
        // Write request; flush to push it through before we read.
        pipe.write_all(&request_data)
            .map_err(|e| PipeError::Write(e.to_string()))?;
        pipe.flush()
            .map_err(|e| PipeError::Write(e.to_string()))?;
        // Read response. A single read relies on the agent answering with one
        // message that fits in MAX_RESPONSE_SIZE (named pipe message boundary).
        let mut response_data = vec![0u8; MAX_RESPONSE_SIZE];
        let n = pipe
            .read(&mut response_data)
            .map_err(|e| PipeError::Read(e.to_string()))?;
        serde_json::from_slice(&response_data[..n])
            .map_err(|e| PipeError::Deserialization(e.to_string()))
    }

    /// Open the agent's named pipe, retrying briefly only when all pipe
    /// instances are busy (ERROR_PIPE_BUSY). Any other error — e.g. the agent
    /// is not running and the pipe does not exist — fails immediately instead
    /// of wasting the retry budget.
    fn open_pipe() -> Result<std::fs::File, PipeError> {
        use std::fs::OpenOptions;

        // Windows ERROR_PIPE_BUSY: all pipe instances are currently in use.
        const ERROR_PIPE_BUSY: i32 = 231;

        let mut last_err: Option<std::io::Error> = None;
        for _ in 0..3 {
            match OpenOptions::new().read(true).write(true).open(PIPE_NAME) {
                Ok(file) => return Ok(file),
                Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY) => {
                    // Busy pipe: back off briefly and retry.
                    std::thread::sleep(Duration::from_millis(100));
                    last_err = Some(e);
                }
                Err(e) => {
                    return Err(PipeError::Connect(format!(
                        "failed to open pipe {}: {}",
                        PIPE_NAME, e
                    )));
                }
            }
        }
        Err(PipeError::Connect(format!(
            "pipe {} still busy after retries: {}",
            PIPE_NAME,
            last_err.map(|e| e.to_string()).unwrap_or_default()
        )))
    }
}
/// Errors that can occur during named pipe communication, tagged by the
/// stage of the exchange that failed.
#[derive(Debug)]
pub enum PipeError {
    Connect(String),
    Write(String),
    Read(String),
    Serialization(String),
    Deserialization(String),
}

impl std::fmt::Display for PipeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant renders as "pipe <stage>: <detail>".
        let (stage, detail) = match self {
            PipeError::Connect(detail) => ("connect", detail),
            PipeError::Write(detail) => ("write", detail),
            PipeError::Read(detail) => ("read", detail),
            PipeError::Serialization(detail) => ("serialization", detail),
            PipeError::Deserialization(detail) => ("deserialization", detail),
        };
        write!(f, "pipe {}: {}", stage, detail)
    }
}

impl std::error::Error for PipeError {}

View File

@@ -0,0 +1,270 @@
//! ICredentialProvider implementation.
//!
//! This is the main COM object that Windows' LogonUI.exe instantiates.
//! It determines whether to show a "NetBird Login" credential tile based on
//! whether the NetBird agent has a pending RDP session for the connecting peer.
use crate::credential::NetBirdCredential;
use crate::guid::CLSID_NETBIRD_CREDENTIAL_PROVIDER;
use crate::named_pipe_client::NamedPipeClient;
use std::sync::Mutex;
use windows::core::*;
use windows::Win32::Foundation::*;
use windows::Win32::Security::Credentials::*;
use windows::Win32::System::RemoteDesktop::*;
/// The NetBird Credential Provider, loaded by LogonUI.exe via COM.
#[implement(ICredentialProvider)]
pub struct NetBirdCredentialProvider {
    /// The credential tile (if a pending session was found).
    /// Populated by GetCredentialCount, handed out by GetCredentialAt.
    credential: Mutex<Option<ICredentialProviderCredential>>,
    /// Whether this provider is active for the current usage scenario.
    /// Set by SetUsageScenario.
    active: Mutex<bool>,
}

impl NetBirdCredentialProvider {
    /// Creates an inactive provider with no credential tile; both fields are
    /// populated later by SetUsageScenario / GetCredentialCount.
    pub fn new() -> Self {
        Self {
            credential: Mutex::new(None),
            active: Mutex::new(false),
        }
    }
}
impl ICredentialProvider_Impl for NetBirdCredentialProvider_Impl {
    // Activates the provider only for logon/unlock; returning E_NOTIMPL for
    // other scenarios (e.g. CredUI) tells LogonUI to skip this provider.
    fn SetUsageScenario(
        &self,
        cpus: CREDENTIAL_PROVIDER_USAGE_SCENARIO,
        _dwflags: u32,
    ) -> Result<()> {
        let mut active = self.active.lock().unwrap();
        match cpus {
            CPUS_LOGON | CPUS_UNLOCK_WORKSTATION => {
                // We activate for RDP logon and unlock scenarios
                *active = true;
                log::info!("NetBird CP activated for usage scenario {:?}", cpus.0);
                Ok(())
            }
            _ => {
                // Don't activate for credui or other scenarios
                *active = false;
                Err(E_NOTIMPL.into())
            }
        }
    }
    // We never accept pre-serialized credentials from the RDP client.
    fn SetSerialization(
        &self,
        _pcpcs: *const CREDENTIAL_PROVIDER_CREDENTIAL_SERIALIZATION,
    ) -> Result<()> {
        Err(E_NOTIMPL.into())
    }
    // No asynchronous credential updates are raised, so the event sink is ignored.
    fn Advise(
        &self,
        _pcpe: Option<&ICredentialProviderEvents>,
        _upadvisecontext: usize,
    ) -> Result<()> {
        Ok(())
    }
    fn UnAdvise(&self) -> Result<()> {
        Ok(())
    }
    fn GetFieldDescriptorCount(&self) -> Result<u32> {
        // We have one field: a large text label showing "NetBird: Logging in as <user>"
        Ok(1)
    }
    // Returns the descriptor for field 0. Both the descriptor struct and its
    // label string are CoTaskMemAlloc'd because LogonUI frees them with
    // CoTaskMemFree.
    fn GetFieldDescriptorAt(
        &self,
        _dwindex: u32,
        _ppcpfd: *mut *mut CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR,
    ) -> Result<()> {
        if _dwindex != 0 {
            return Err(E_INVALIDARG.into());
        }
        let label = "NetBird Login";
        let wide: Vec<u16> = label.encode_utf16().chain(std::iter::once(0)).collect();
        unsafe {
            let desc = windows::Win32::System::Com::CoTaskMemAlloc(
                std::mem::size_of::<CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR>(),
            ) as *mut CREDENTIAL_PROVIDER_FIELD_DESCRIPTOR;
            if desc.is_null() {
                return Err(E_OUTOFMEMORY.into());
            }
            let label_mem =
                windows::Win32::System::Com::CoTaskMemAlloc(wide.len() * 2) as *mut u16;
            if label_mem.is_null() {
                // Don't leak the descriptor if the label allocation fails.
                windows::Win32::System::Com::CoTaskMemFree(Some(desc as *const _));
                return Err(E_OUTOFMEMORY.into());
            }
            std::ptr::copy_nonoverlapping(wide.as_ptr(), label_mem, wide.len());
            (*desc).dwFieldID = 0;
            (*desc).cpft = CPFT_LARGE_TEXT;
            (*desc).pszLabel = PWSTR(label_mem);
            (*desc).guidFieldType = GUID::zeroed();
            *_ppcpfd = desc;
        }
        Ok(())
    }
    // Decides whether to show the NetBird tile: only when the provider is
    // active, the RDP client IP can be resolved, and the agent reports a
    // pending session for that IP. All failure paths report zero credentials
    // rather than an error so other providers still work.
    fn GetCredentialCount(
        &self,
        _pdwcount: *mut u32,
        _pdwdefault: *mut u32,
        _pbautologinwithdefault: *mut BOOL,
    ) -> Result<()> {
        let active = self.active.lock().unwrap();
        if !*active {
            unsafe {
                // u32::MAX == CREDENTIAL_PROVIDER_NO_DEFAULT.
                *_pdwcount = 0;
                *_pdwdefault = u32::MAX;
                *_pbautologinwithdefault = FALSE;
            }
            return Ok(());
        }
        // Try to get the client IP of the current RDP session
        let remote_ip = match get_rdp_client_ip() {
            Some(ip) => ip,
            None => {
                log::debug!("NetBird CP: could not determine RDP client IP");
                unsafe {
                    *_pdwcount = 0;
                    *_pdwdefault = u32::MAX;
                    *_pbautologinwithdefault = FALSE;
                }
                return Ok(());
            }
        };
        // Query the NetBird agent for a pending session
        match NamedPipeClient::query_pending(&remote_ip) {
            Ok(response) if response.found => {
                log::info!(
                    "NetBird CP: found pending session for {} -> {}",
                    remote_ip,
                    response.os_user
                );
                let cred = NetBirdCredential::new(remote_ip, response);
                let icred: ICredentialProviderCredential = cred.into();
                let mut credential = self.credential.lock().unwrap();
                *credential = Some(icred);
                unsafe {
                    *_pdwcount = 1;
                    *_pdwdefault = 0;
                    *_pbautologinwithdefault = TRUE; // auto-logon
                }
            }
            Ok(_) => {
                log::debug!("NetBird CP: no pending session for {}", remote_ip);
                unsafe {
                    *_pdwcount = 0;
                    *_pdwdefault = u32::MAX;
                    *_pbautologinwithdefault = FALSE;
                }
            }
            Err(e) => {
                log::debug!("NetBird CP: pipe query failed: {}", e);
                unsafe {
                    *_pdwcount = 0;
                    *_pdwdefault = u32::MAX;
                    *_pbautologinwithdefault = FALSE;
                }
            }
        }
        Ok(())
    }
    // Hands out the single tile created by GetCredentialCount; E_UNEXPECTED
    // if LogonUI asks before a tile exists.
    fn GetCredentialAt(
        &self,
        _dwindex: u32,
        _ppcpc: *mut Option<ICredentialProviderCredential>,
    ) -> Result<()> {
        if _dwindex != 0 {
            return Err(E_INVALIDARG.into());
        }
        let credential = self.credential.lock().unwrap();
        match &*credential {
            Some(cred) => {
                unsafe {
                    *_ppcpc = Some(cred.clone());
                }
                Ok(())
            }
            None => Err(E_UNEXPECTED.into()),
        }
    }
}
/// Get the IP address of the remote RDP client for the current session.
///
/// Resolves the session ID of the calling process (LogonUI runs inside the
/// target session), then queries WTSClientAddress for that session. Returns
/// None when the session has no resolvable client address.
fn get_rdp_client_ip() -> Option<String> {
    unsafe {
        // Get the current session ID
        let process_id = windows::Win32::System::Threading::GetCurrentProcessId();
        let mut session_id = 0u32;
        if !windows::Win32::System::RemoteDesktop::ProcessIdToSessionId(process_id, &mut session_id)
            .as_bool()
        {
            log::debug!("ProcessIdToSessionId failed");
            return None;
        }
        // Query the client address
        let mut buffer: *mut WTS_CLIENT_ADDRESS = std::ptr::null_mut();
        let mut bytes_returned = 0u32;
        let result = WTSQuerySessionInformationW(
            WTS_CURRENT_SERVER_HANDLE,
            session_id,
            WTS_INFO_CLASS(14), // WTSClientAddress
            &mut buffer as *mut _ as *mut *mut u16,
            &mut bytes_returned,
        );
        if !result.as_bool() || buffer.is_null() {
            log::debug!("WTSQuerySessionInformation(WTSClientAddress) failed");
            return None;
        }
        let client_addr = &*buffer;
        let ip = match client_addr.AddressFamily as u32 {
            // AF_INET: per the WTS_CLIENT_ADDRESS docs the IPv4 address
            // occupies Address[2..6].
            2 => {
                let addr = &client_addr.Address;
                Some(format!("{}.{}.{}.{}", addr[2], addr[3], addr[4], addr[5]))
            }
            // AF_INET6 (23 on Windows)
            23 => {
                // IPv6 - extract from Address bytes
                // NOTE(review): assumes the 16 IPv6 bytes also start at
                // offset 2 — confirm against WTS_CLIENT_ADDRESS docs before
                // relying on IPv6 matching.
                let addr = &client_addr.Address;
                Some(format!(
                    "{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}",
                    addr[2], addr[3], addr[4], addr[5], addr[6], addr[7], addr[8], addr[9],
                    addr[10], addr[11], addr[12], addr[13], addr[14], addr[15], addr[16], addr[17]
                ))
            }
            _ => None,
        };
        // The buffer was allocated by WTS; release it with WTSFreeMemory.
        WTSFreeMemory(buffer as *mut std::ffi::c_void);
        ip
    }
}

View File

@@ -0,0 +1,398 @@
//! S4U (Service for User) authentication for Windows.
//!
//! This module ports the S4U logon logic from the Go implementation at:
//! `client/ssh/server/executor_windows.go:generateS4UUserToken()`
//!
//! It creates Windows logon tokens without requiring a password, using the LSA
//! (Local Security Authority) S4U mechanism. This is the same approach used by
//! OpenSSH for Windows for public key authentication.
use std::ptr;
use windows::core::{PCSTR, PWSTR};
use windows::Win32::Foundation::{HANDLE, LUID, NTSTATUS, PSID};
use windows::Win32::Security::Authentication::Identity::{
LsaDeregisterLogonProcess, LsaFreeReturnBuffer, LsaLogonUser, LsaLookupAuthenticationPackage,
LsaRegisterLogonProcess, KERB_S4U_LOGON, MSV1_0_S4U_LOGON, MSV1_0_S4U_LOGON_FLAG_CHECK_LOGONHOURS,
SECURITY_LOGON_TYPE,
};
use windows::Win32::Security::{
QUOTA_LIMITS, TOKEN_SOURCE,
};
/// Status code for successful LSA operations (NTSTATUS STATUS_SUCCESS).
const STATUS_SUCCESS: i32 = 0;
/// Network logon type (used for S4U).
const LOGON32_LOGON_NETWORK: SECURITY_LOGON_TYPE = SECURITY_LOGON_TYPE(3);
/// Kerberos S4U logon message type (KerbS4ULogon).
const KERB_S4U_LOGON_TYPE: u32 = 12;
/// MSV1_0 S4U logon message type (MsV1_0S4ULogon).
/// Note: both submit-type enums happen to use the value 12.
const MSV1_0_S4U_LOGON_TYPE: u32 = 12;
/// Authentication package name for Kerberos (domain accounts).
const KERBEROS_PACKAGE: &str = "Kerberos";
/// Authentication package name for MSV1_0 (local users).
const MSV1_0_PACKAGE: &str = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0";
/// Result of a successful S4U logon.
pub struct S4UToken {
    /// Windows access token handle; closed automatically when dropped.
    pub handle: HANDLE,
}

impl Drop for S4UToken {
    // Close the token handle exactly once when the wrapper goes out of scope.
    fn drop(&mut self) {
        if !self.handle.is_invalid() {
            unsafe {
                let _ = windows::Win32::Foundation::CloseHandle(self.handle);
            }
        }
    }
}
/// Errors from S4U logon operations, carrying the failing NTSTATUS where the
/// LSA reported one.
#[derive(Debug)]
pub enum S4UError {
    LsaRegister(NTSTATUS),
    LookupPackage(NTSTATUS),
    LogonUser(NTSTATUS, i32),
    AllocateLuid,
    InvalidUsername(String),
    Utf16Conversion(String),
}

impl std::fmt::Display for S4UError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use S4UError::*;
        match self {
            LsaRegister(status) => write!(f, "LsaRegisterLogonProcess: 0x{:x}", status.0),
            LookupPackage(status) => write!(f, "LsaLookupAuthenticationPackage: 0x{:x}", status.0),
            LogonUser(status, sub_status) => write!(
                f,
                "LsaLogonUser S4U: NTSTATUS=0x{:x}, SubStatus=0x{:x}",
                status.0, sub_status
            ),
            AllocateLuid => write!(f, "AllocateLocallyUniqueId failed"),
            InvalidUsername(name) => write!(f, "invalid username: {}", name),
            Utf16Conversion(detail) => write!(f, "UTF-16 conversion: {}", detail),
        }
    }
}

impl std::error::Error for S4UError {}
/// Generate a Windows logon token using S4U authentication.
///
/// This creates a token for the specified user without requiring a password.
/// The calling process must have SeTcbPrivilege (typically SYSTEM).
///
/// # Arguments
/// * `username` - The Windows username (without domain prefix)
/// * `domain` - The domain name ("." for local users)
///
/// # Returns
/// An `S4UToken` containing the Windows logon token handle.
pub fn generate_s4u_token(username: &str, domain: &str) -> Result<S4UToken, S4UError> {
    if username.is_empty() {
        return Err(S4UError::InvalidUsername("empty username".to_string()));
    }
    let local_account = is_local_user(domain);

    // Register with LSA, look up the auth package, run the logon, and make
    // sure the LSA connection is torn down regardless of the logon outcome.
    let lsa_handle = initialize_lsa_connection()?;
    let package_id = lookup_auth_package(lsa_handle, local_account)?;
    let logon_result = perform_s4u_logon(lsa_handle, package_id, username, domain, local_account);
    unsafe {
        let _ = LsaDeregisterLogonProcess(lsa_handle);
    }
    logon_result
}
/// A user is "local" (machine account database) when no domain is given or
/// the domain is the "." shorthand.
fn is_local_user(domain: &str) -> bool {
    matches!(domain, "" | ".")
}
/// Registers this process with LSA as a logon process named "NetBird".
/// Requires SeTcbPrivilege (i.e. running as SYSTEM); returns the LSA handle
/// used for package lookup and LsaLogonUser. The caller is responsible for
/// releasing it with LsaDeregisterLogonProcess.
fn initialize_lsa_connection() -> Result<HANDLE, S4UError> {
    // LSA_STRING takes an ANSI buffer with explicit lengths: Length excludes
    // the trailing NUL, MaximumLength includes it.
    let process_name = "NetBird\0";
    let mut lsa_string = windows::Win32::Security::Authentication::Identity::LSA_STRING {
        Length: (process_name.len() - 1) as u16,
        MaximumLength: process_name.len() as u16,
        Buffer: windows::core::PSTR(process_name.as_ptr() as *mut u8),
    };
    let mut lsa_handle = HANDLE::default();
    // Security mode output parameter; value unused.
    let mut mode = 0u32;
    let status = unsafe {
        LsaRegisterLogonProcess(&mut lsa_string, &mut lsa_handle, &mut mode)
    };
    if status.0 != STATUS_SUCCESS {
        return Err(S4UError::LsaRegister(status));
    }
    Ok(lsa_handle)
}
/// Resolves the LSA authentication package ID to use for the logon:
/// MSV1_0 for local accounts, Kerberos for domain accounts.
fn lookup_auth_package(lsa_handle: HANDLE, is_local: bool) -> Result<u32, S4UError> {
    let package_name = if is_local { MSV1_0_PACKAGE } else { KERBEROS_PACKAGE };
    // LSA_STRING wants an ANSI buffer; Length excludes the NUL we append.
    let package_with_null = format!("{}\0", package_name);
    let mut lsa_string = windows::Win32::Security::Authentication::Identity::LSA_STRING {
        Length: (package_with_null.len() - 1) as u16,
        MaximumLength: package_with_null.len() as u16,
        Buffer: windows::core::PSTR(package_with_null.as_ptr() as *mut u8),
    };
    let mut auth_package_id = 0u32;
    let status = unsafe {
        LsaLookupAuthenticationPackage(lsa_handle, &mut lsa_string, &mut auth_package_id)
    };
    if status.0 != STATUS_SUCCESS {
        return Err(S4UError::LookupPackage(status));
    }
    Ok(auth_package_id)
}
/// Perform the actual S4U logon via `LsaLogonUser` and wrap the resulting
/// token handle in an `S4UToken`.
///
/// Builds an MSV1_0 (local) or Kerberos (domain) S4U logon buffer, invokes
/// LSA, and frees both the LSA-allocated profile buffer and our own logon
/// buffer regardless of the outcome.
fn perform_s4u_logon(
    lsa_handle: HANDLE,
    auth_package_id: u32,
    username: &str,
    domain: &str,
    is_local: bool,
) -> Result<S4UToken, S4UError> {
    // Token source: fixed 8-byte name plus a best-effort unique LUID built
    // from the current time and PID. (The previous version additionally
    // called GetSystemTimeAsFileTime on a discarded zeroed temporary and
    // wrapped the block in an always-true `alloc_ok` check; both were dead
    // code and have been removed.)
    let mut source_name = [0u8; 8];
    let name_bytes = b"netbird";
    source_name[..name_bytes.len()].copy_from_slice(name_bytes);
    let mut source_id = LUID::default();
    source_id.LowPart = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();
    source_id.HighPart = std::process::id() as i32;
    let token_source = TOKEN_SOURCE {
        SourceName: source_name,
        SourceIdentifier: source_id,
    };
    // Origin name recorded by LSA for auditing.
    let origin_name_str = "netbird\0";
    let mut origin_name = windows::Win32::Security::Authentication::Identity::LSA_STRING {
        Length: (origin_name_str.len() - 1) as u16,
        MaximumLength: origin_name_str.len() as u16,
        Buffer: windows::core::PSTR(origin_name_str.as_ptr() as *mut u8),
    };
    // Build the package-specific logon info structure (heap buffer we own and
    // must free below).
    let (logon_info_ptr, logon_info_size) = if is_local {
        build_msv1_0_s4u_logon(username)?
    } else {
        build_kerb_s4u_logon(username, domain)?
    };
    let mut profile: *mut std::ffi::c_void = ptr::null_mut();
    let mut profile_size = 0u32;
    let mut logon_id = LUID::default();
    let mut token = HANDLE::default();
    let mut quotas = QUOTA_LIMITS::default();
    let mut sub_status: i32 = 0;
    let status = unsafe {
        LsaLogonUser(
            lsa_handle,
            &mut origin_name,
            LOGON32_LOGON_NETWORK,
            auth_package_id,
            logon_info_ptr as *const std::ffi::c_void,
            logon_info_size as u32,
            None, // local groups
            &token_source,
            &mut profile,
            &mut profile_size,
            &mut logon_id,
            &mut token,
            &mut quotas,
            &mut sub_status,
        )
    };
    // Free the LSA-allocated profile buffer if one was returned.
    if !profile.is_null() {
        unsafe {
            let _ = LsaFreeReturnBuffer(profile);
        }
    }
    // Free our logon info buffer with the same size/alignment the build_*
    // helpers allocated it with.
    unsafe {
        let layout = std::alloc::Layout::from_size_align_unchecked(logon_info_size, 8);
        std::alloc::dealloc(logon_info_ptr as *mut u8, layout);
    }
    if status.0 != STATUS_SUCCESS {
        return Err(S4UError::LogonUser(status, sub_status));
    }
    Ok(S4UToken { handle: token })
}
/// Build MSV1_0_S4U_LOGON structure for local users.
fn build_msv1_0_s4u_logon(username: &str) -> Result<(*mut u8, usize), S4UError> {
let username_utf16: Vec<u16> = username.encode_utf16().chain(std::iter::once(0)).collect();
let domain_utf16: Vec<u16> = ".".encode_utf16().chain(std::iter::once(0)).collect();
let username_byte_size = username_utf16.len() * 2;
let domain_byte_size = domain_utf16.len() * 2;
// MSV1_0_S4U_LOGON structure:
// MessageType: u32 (4 bytes)
// Flags: u32 (4 bytes)
// UserPrincipalName: UNICODE_STRING (8 bytes on 32-bit, 16 bytes on 64-bit)
// DomainName: UNICODE_STRING
let struct_size = std::mem::size_of::<MSV1_0_S4U_LOGON_HEADER>();
let total_size = struct_size + username_byte_size + domain_byte_size;
let layout = std::alloc::Layout::from_size_align(total_size, 8).unwrap();
let buffer = unsafe { std::alloc::alloc_zeroed(layout) };
if buffer.is_null() {
return Err(S4UError::Utf16Conversion("allocation failed".to_string()));
}
// For the POC, we'll set up the raw bytes manually since the windows-rs
// MSV1_0_S4U_LOGON structure layout may differ.
// This is a simplified version - in production, use proper FFI bindings.
unsafe {
// MessageType = MSV1_0_S4U_LOGON_TYPE (12)
*(buffer as *mut u32) = MSV1_0_S4U_LOGON_TYPE;
// Flags = 0
*((buffer as *mut u32).add(1)) = 0;
// Copy username UTF-16 after the structure
let username_offset = struct_size;
let username_dest = buffer.add(username_offset);
ptr::copy_nonoverlapping(
username_utf16.as_ptr() as *const u8,
username_dest,
username_byte_size,
);
// Copy domain UTF-16 after username
let domain_offset = username_offset + username_byte_size;
let domain_dest = buffer.add(domain_offset);
ptr::copy_nonoverlapping(
domain_utf16.as_ptr() as *const u8,
domain_dest,
domain_byte_size,
);
// Set UNICODE_STRING for UserPrincipalName (offset 8 on 64-bit)
// Length, MaximumLength, Buffer pointer
let upn_ptr = buffer.add(8) as *mut u16;
*upn_ptr = ((username_utf16.len() - 1) * 2) as u16; // Length (without null)
*(upn_ptr.add(1)) = (username_utf16.len() * 2) as u16; // MaximumLength
*((buffer.add(8 + 4)) as *mut *const u8) = username_dest; // Buffer
// Set UNICODE_STRING for DomainName
let dn_offset = 8 + std::mem::size_of::<UnicodeStringRaw>();
let dn_ptr = buffer.add(dn_offset) as *mut u16;
*dn_ptr = ((domain_utf16.len() - 1) * 2) as u16;
*(dn_ptr.add(1)) = (domain_utf16.len() * 2) as u16;
*((buffer.add(dn_offset + 4)) as *mut *const u8) = domain_dest;
}
Ok((buffer, total_size))
}
/// Build KERB_S4U_LOGON structure for domain users.
fn build_kerb_s4u_logon(username: &str, domain: &str) -> Result<(*mut u8, usize), S4UError> {
// Build UPN: username@domain
let upn = format!("{}@{}", username, domain);
let upn_utf16: Vec<u16> = upn.encode_utf16().chain(std::iter::once(0)).collect();
let upn_byte_size = upn_utf16.len() * 2;
let struct_size = std::mem::size_of::<KerbS4ULogonHeader>();
let total_size = struct_size + upn_byte_size;
let layout = std::alloc::Layout::from_size_align(total_size, 8).unwrap();
let buffer = unsafe { std::alloc::alloc_zeroed(layout) };
if buffer.is_null() {
return Err(S4UError::Utf16Conversion("allocation failed".to_string()));
}
unsafe {
// MessageType = KERB_S4U_LOGON_TYPE (12)
*(buffer as *mut u32) = KERB_S4U_LOGON_TYPE;
// Flags = 0
*((buffer as *mut u32).add(1)) = 0;
// Copy UPN UTF-16 after the structure
let upn_offset = struct_size;
let upn_dest = buffer.add(upn_offset);
ptr::copy_nonoverlapping(
upn_utf16.as_ptr() as *const u8,
upn_dest,
upn_byte_size,
);
// Set UNICODE_STRING for ClientUpn (offset 8)
let upn_str_ptr = buffer.add(8) as *mut u16;
*upn_str_ptr = ((upn_utf16.len() - 1) * 2) as u16;
*(upn_str_ptr.add(1)) = (upn_utf16.len() * 2) as u16;
*((buffer.add(8 + 4)) as *mut *const u8) = upn_dest;
// ClientRealm is empty (zeroed)
}
Ok((buffer, total_size))
}
/// Raw UNICODE_STRING layout for size calculation.
/// Mirrors the Win32 UNICODE_STRING: two u16 byte lengths followed by a
/// pointer, which sits at offset 8 on 64-bit targets due to alignment padding.
#[repr(C)]
struct UnicodeStringRaw {
    _length: u16,          // used byte length, excluding any trailing NUL
    _maximum_length: u16,  // allocated byte length
    _buffer: *const u16,   // UTF-16 payload; not owned by this struct
}
/// Header size for MSV1_0_S4U_LOGON (MessageType + Flags + 2x UNICODE_STRING).
/// Matches the Win32 MSV1_0_S4U_LOGON struct; the UTF-16 string payload is
/// appended immediately after this header in the same allocation.
#[repr(C)]
struct MSV1_0_S4U_LOGON_HEADER {
    _message_type: u32,                     // MSV1_0_S4U_LOGON_TYPE
    _flags: u32,
    _user_principal_name: UnicodeStringRaw, // points into trailing payload
    _domain_name: UnicodeStringRaw,         // points into trailing payload
}
/// Header size for KERB_S4U_LOGON (MessageType + Flags + 2x UNICODE_STRING).
/// Matches the Win32 KERB_S4U_LOGON struct; the UPN UTF-16 payload is appended
/// immediately after this header in the same allocation.
#[repr(C)]
struct KerbS4ULogonHeader {
    _message_type: u32,              // KERB_S4U_LOGON_TYPE
    _flags: u32,
    _client_upn: UnicodeStringRaw,   // points into trailing payload
    _client_realm: UnicodeStringRaw, // left zeroed (default realm)
}

21
client/rdp/server/addr.go Normal file
View File

@@ -0,0 +1,21 @@
package server
import (
"fmt"
"net/netip"
)
// parseAddr parses a string into a netip.Addr, accepting either a plain IP
// ("100.64.0.1") or an IP:port pair ("100.64.0.1:3389"). Any port and IPv6
// zone are stripped from the result. (The previous version documented zone
// stripping but preserved zones; WithZone("") makes the behavior match.)
func parseAddr(s string) (netip.Addr, error) {
	// Try as plain IP first
	if addr, err := netip.ParseAddr(s); err == nil {
		return addr.WithZone(""), nil
	}
	// Try as IP:port
	if addrPort, err := netip.ParseAddrPort(s); err == nil {
		return addrPort.Addr().WithZone(""), nil
	}
	return netip.Addr{}, fmt.Errorf("invalid IP address: %s", s)
}

View File

@@ -0,0 +1,13 @@
//go:build !windows
package server
// RegisterCredentialProvider is a no-op on non-Windows platforms; the COM
// credential provider DLL only exists on Windows.
func RegisterCredentialProvider() error {
	return nil
}
// UnregisterCredentialProvider is a no-op on non-Windows platforms; there is
// nothing to unregister outside Windows.
func UnregisterCredentialProvider() error {
	return nil
}

View File

@@ -0,0 +1,66 @@
//go:build windows
package server
import (
"fmt"
"os"
"os/exec"
"path/filepath"
log "github.com/sirupsen/logrus"
)
const (
	// credProvDLLName is the filename of the credential provider DLL,
	// expected next to the netbird executable (see findCredProvDLL).
	credProvDLLName = "netbird_credprov.dll"
)
// RegisterCredentialProvider registers the NetBird Credential Provider COM DLL
// using regsvr32 (silent mode). The DLL must be shipped alongside the NetBird
// executable.
func RegisterCredentialProvider() error {
	dllPath, err := findCredProvDLL()
	if err != nil {
		return fmt.Errorf("find credential provider DLL: %w", err)
	}
	out, err := exec.Command("regsvr32", "/s", dllPath).CombinedOutput()
	if err != nil {
		return fmt.Errorf("regsvr32 %s: %w (output: %s)", dllPath, err, string(out))
	}
	log.Infof("registered RDP credential provider: %s", dllPath)
	return nil
}
// UnregisterCredentialProvider unregisters the NetBird Credential Provider COM
// DLL. A missing DLL is treated as already-unregistered and is not an error.
func UnregisterCredentialProvider() error {
	dllPath, err := findCredProvDLL()
	if err != nil {
		log.Debugf("credential provider DLL not found for unregistration: %v", err)
		return nil
	}
	out, err := exec.Command("regsvr32", "/s", "/u", dllPath).CombinedOutput()
	if err != nil {
		return fmt.Errorf("regsvr32 /u %s: %w (output: %s)", dllPath, err, string(out))
	}
	log.Infof("unregistered RDP credential provider: %s", dllPath)
	return nil
}
// findCredProvDLL locates the credential provider DLL next to the running
// executable and verifies it exists on disk.
func findCredProvDLL() (string, error) {
	exePath, err := os.Executable()
	if err != nil {
		return "", fmt.Errorf("get executable path: %w", err)
	}
	candidate := filepath.Join(filepath.Dir(exePath), credProvDLLName)
	if _, err := os.Stat(candidate); err != nil {
		return "", fmt.Errorf("DLL not found at %s: %w", candidate, err)
	}
	return candidate, nil
}

View File

@@ -0,0 +1,184 @@
package server
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net/netip"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
)
const (
	// DefaultSessionTTL is the default time-to-live for pending RDP sessions;
	// an authorized session must be consumed within this window or it expires.
	DefaultSessionTTL = 60 * time.Second
	// cleanupInterval is how often the store checks for expired sessions.
	cleanupInterval = 10 * time.Second
	// nonceLength is the length of the nonce in bytes (hex-encoded on the
	// wire, so twice this many characters).
	nonceLength = 32
)
// PendingRDPSession represents an authorized but not yet consumed RDP session.
// It is created by the sideband auth server and handed to the credential
// provider (via the named pipe) exactly once.
type PendingRDPSession struct {
	SessionID  string     // unique ID assigned at creation
	PeerIP     netip.Addr // WireGuard IP of the connecting peer
	OSUsername string     // Windows account to log on as
	Domain     string     // Windows domain; "." for local accounts
	JWTUserID  string     // for audit trail
	Nonce      string     // replay protection
	CreatedAt  time.Time
	ExpiresAt  time.Time
	consumed   bool // set once the credential provider uses the session
}
// PendingStore manages pending RDP session entries with automatic expiration.
// All fields are guarded by mu; the struct must not be copied.
type PendingStore struct {
	mu       sync.RWMutex
	sessions map[string]*PendingRDPSession // keyed by SessionID
	nonces   map[string]struct{}           // seen nonces for replay protection
	ttl      time.Duration                 // lifetime of a new session
}
// NewPendingStore creates a new pending session store with the given TTL;
// a non-positive TTL falls back to DefaultSessionTTL.
func NewPendingStore(ttl time.Duration) *PendingStore {
	effective := ttl
	if effective <= 0 {
		effective = DefaultSessionTTL
	}
	return &PendingStore{
		ttl:      effective,
		sessions: map[string]*PendingRDPSession{},
		nonces:   map[string]struct{}{},
	}
}
// Add creates a new pending RDP session and returns it. The nonce must be
// unique across the store's lifetime; a repeated nonce is rejected to defend
// against replayed authorization requests.
func (ps *PendingStore) Add(peerIP netip.Addr, osUsername, domain, jwtUserID, nonce string) (*PendingRDPSession, error) {
	ps.mu.Lock()
	defer ps.mu.Unlock()

	// Replay protection: each nonce may be used only once.
	if _, dup := ps.nonces[nonce]; dup {
		return nil, fmt.Errorf("duplicate nonce: replay detected")
	}
	ps.nonces[nonce] = struct{}{}

	created := time.Now()
	entry := &PendingRDPSession{
		SessionID:  uuid.New().String(),
		PeerIP:     peerIP,
		OSUsername: osUsername,
		Domain:     domain,
		JWTUserID:  jwtUserID,
		Nonce:      nonce,
		CreatedAt:  created,
		ExpiresAt:  created.Add(ps.ttl),
	}
	ps.sessions[entry.SessionID] = entry

	log.Debugf("RDP pending session created: id=%s peer=%s user=%s domain=%s expires=%s",
		entry.SessionID, peerIP, osUsername, domain, entry.ExpiresAt.Format(time.RFC3339))
	return entry, nil
}
// QueryByPeerIP finds the first non-consumed, non-expired pending session for
// the given peer IP. Map iteration order is random, so with several live
// sessions for one peer an arbitrary one is returned.
func (ps *PendingStore) QueryByPeerIP(peerIP netip.Addr) (*PendingRDPSession, bool) {
	ps.mu.RLock()
	defer ps.mu.RUnlock()

	cutoff := time.Now()
	for _, candidate := range ps.sessions {
		if candidate.consumed || candidate.PeerIP != peerIP {
			continue
		}
		if cutoff.Before(candidate.ExpiresAt) {
			return candidate, true
		}
	}
	return nil, false
}
// Consume marks a session as consumed (single-use). It reports whether the
// session was found and successfully consumed; it returns false when the
// session is unknown, already consumed, or past its expiry.
func (ps *PendingStore) Consume(sessionID string) bool {
	ps.mu.Lock()
	defer ps.mu.Unlock()

	session, ok := ps.sessions[sessionID]
	switch {
	case !ok:
		return false
	case session.consumed:
		log.Debugf("RDP pending session already consumed: id=%s", sessionID)
		return false
	case time.Now().After(session.ExpiresAt):
		log.Debugf("RDP pending session expired: id=%s", sessionID)
		return false
	}

	session.consumed = true
	log.Debugf("RDP pending session consumed: id=%s peer=%s user=%s",
		sessionID, session.PeerIP, session.OSUsername)
	return true
}
// StartCleanup runs a background goroutine that periodically removes expired
// sessions; the goroutine exits when ctx is canceled.
func (ps *PendingStore) StartCleanup(ctx context.Context) {
	go func() {
		ticker := time.NewTicker(cleanupInterval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				ps.cleanup()
			case <-ctx.Done():
				return
			}
		}
	}()
}
// cleanup removes expired and consumed sessions, releasing their nonces so the
// nonce set does not grow without bound.
func (ps *PendingStore) cleanup() {
	ps.mu.Lock()
	defer ps.mu.Unlock()

	cutoff := time.Now()
	for id, entry := range ps.sessions {
		if !entry.consumed && !cutoff.After(entry.ExpiresAt) {
			continue
		}
		delete(ps.sessions, id)
		delete(ps.nonces, entry.Nonce)
	}
}
// Count returns the number of active (non-expired, non-consumed) sessions.
func (ps *PendingStore) Count() int {
	ps.mu.RLock()
	defer ps.mu.RUnlock()

	cutoff := time.Now()
	active := 0
	for _, entry := range ps.sessions {
		if !entry.consumed && cutoff.Before(entry.ExpiresAt) {
			active++
		}
	}
	return active
}
// GenerateNonce creates a cryptographically random nonce for replay
// protection, returned as a hex string of 2*nonceLength characters.
func GenerateNonce() (string, error) {
	raw := make([]byte, nonceLength)
	if _, err := rand.Read(raw); err != nil {
		return "", fmt.Errorf("generate nonce: %w", err)
	}
	return hex.EncodeToString(raw), nil
}

View File

@@ -0,0 +1,268 @@
package server
import (
	"context"
	"fmt"
	"net/netip"
	"sync"
	"testing"
	"time"
)
// TestPendingStore_AddAndQuery verifies that Add populates the session fields
// and that QueryByPeerIP matches only the exact peer IP.
func TestPendingStore_AddAndQuery(t *testing.T) {
	store := NewPendingStore(DefaultSessionTTL)
	peerIP := netip.MustParseAddr("100.64.0.1")
	session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-1")
	if err != nil {
		t.Fatalf("Add failed: %v", err)
	}
	if session.SessionID == "" {
		t.Fatal("expected non-empty session ID")
	}
	if session.PeerIP != peerIP {
		t.Errorf("expected peer IP %s, got %s", peerIP, session.PeerIP)
	}
	if session.OSUsername != "admin" {
		t.Errorf("expected username admin, got %s", session.OSUsername)
	}
	// Query should find the session
	found, ok := store.QueryByPeerIP(peerIP)
	if !ok {
		t.Fatal("expected to find pending session")
	}
	if found.SessionID != session.SessionID {
		t.Errorf("expected session %s, got %s", session.SessionID, found.SessionID)
	}
	// Query for different IP should not find anything
	_, ok = store.QueryByPeerIP(netip.MustParseAddr("100.64.0.2"))
	if ok {
		t.Fatal("expected no session for different IP")
	}
}
// TestPendingStore_Consume verifies single-use semantics: a session can be
// consumed once, and a consumed session is invisible to QueryByPeerIP.
func TestPendingStore_Consume(t *testing.T) {
	store := NewPendingStore(DefaultSessionTTL)
	peerIP := netip.MustParseAddr("100.64.0.1")
	session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-2")
	if err != nil {
		t.Fatalf("Add failed: %v", err)
	}
	// First consume should succeed
	if !store.Consume(session.SessionID) {
		t.Fatal("expected first consume to succeed")
	}
	// Second consume should fail (already consumed)
	if store.Consume(session.SessionID) {
		t.Fatal("expected second consume to fail")
	}
	// Query should no longer find consumed session
	_, ok := store.QueryByPeerIP(peerIP)
	if ok {
		t.Fatal("expected consumed session to not be found by query")
	}
}
// TestPendingStore_Expiry verifies that sessions become invisible to query and
// unconsumable once their TTL elapses, without waiting for cleanup.
func TestPendingStore_Expiry(t *testing.T) {
	store := NewPendingStore(50 * time.Millisecond)
	peerIP := netip.MustParseAddr("100.64.0.1")
	session, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-3")
	if err != nil {
		t.Fatalf("Add failed: %v", err)
	}
	// Should be found immediately
	_, ok := store.QueryByPeerIP(peerIP)
	if !ok {
		t.Fatal("expected to find session before expiry")
	}
	// Wait for expiry
	time.Sleep(100 * time.Millisecond)
	// Should not be found after expiry
	_, ok = store.QueryByPeerIP(peerIP)
	if ok {
		t.Fatal("expected session to be expired")
	}
	// Consume should also fail
	if store.Consume(session.SessionID) {
		t.Fatal("expected consume of expired session to fail")
	}
}
// TestPendingStore_ReplayProtection verifies that a nonce can authorize at
// most one session; a second Add with the same nonce is rejected.
func TestPendingStore_ReplayProtection(t *testing.T) {
	store := NewPendingStore(DefaultSessionTTL)
	peerIP := netip.MustParseAddr("100.64.0.1")
	_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-same")
	if err != nil {
		t.Fatalf("first Add failed: %v", err)
	}
	// Same nonce should be rejected
	_, err = store.Add(peerIP, "admin", ".", "user@example.com", "nonce-same")
	if err == nil {
		t.Fatal("expected duplicate nonce to be rejected")
	}
}
// TestPendingStore_Cleanup verifies that a direct cleanup() pass removes
// expired sessions and that Count reflects only active ones.
func TestPendingStore_Cleanup(t *testing.T) {
	store := NewPendingStore(50 * time.Millisecond)
	peerIP := netip.MustParseAddr("100.64.0.1")
	_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-cleanup")
	if err != nil {
		t.Fatalf("Add failed: %v", err)
	}
	if store.Count() != 1 {
		t.Fatalf("expected count 1, got %d", store.Count())
	}
	// Wait for expiry then trigger cleanup
	time.Sleep(100 * time.Millisecond)
	store.cleanup()
	if store.Count() != 0 {
		t.Fatalf("expected count 0 after cleanup, got %d", store.Count())
	}
}
// TestPendingStore_CleanupBackground verifies that the StartCleanup goroutine
// eventually makes an expired session unqueryable.
func TestPendingStore_CleanupBackground(t *testing.T) {
	store := NewPendingStore(50 * time.Millisecond)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	store.StartCleanup(ctx)
	peerIP := netip.MustParseAddr("100.64.0.1")
	_, err := store.Add(peerIP, "admin", ".", "user@example.com", "nonce-bg-cleanup")
	if err != nil {
		t.Fatalf("Add failed: %v", err)
	}
	// Wait for expiry + cleanup interval
	time.Sleep(200 * time.Millisecond)
	_, ok := store.QueryByPeerIP(peerIP)
	if ok {
		t.Fatal("expected session to be cleaned up")
	}
}
// TestPendingStore_ConcurrentAccess exercises Add/QueryByPeerIP/Consume from
// 100 goroutines to surface data races under `go test -race`.
//
// Nonces are generated with Sprintf so each goroutine gets a unique,
// printable nonce. (The previous rune-arithmetic scheme produced non-ASCII
// nonces past i≈51, never actually collided despite the comment claiming
// collisions were expected, and silently swallowed Add errors.)
func TestPendingStore_ConcurrentAccess(t *testing.T) {
	store := NewPendingStore(DefaultSessionTTL)
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			ip := netip.AddrFrom4([4]byte{100, 64, byte(i / 256), byte(i % 256)})
			nonce := fmt.Sprintf("nonce-%d", i)
			session, err := store.Add(ip, "admin", ".", "user", nonce)
			if err != nil {
				t.Errorf("Add failed: %v", err)
				return
			}
			store.QueryByPeerIP(ip)
			store.Consume(session.SessionID)
		}(i)
	}
	wg.Wait()
}
// TestPendingStore_MultipleSessions verifies that sessions for distinct peers
// coexist independently and that domain information is preserved.
func TestPendingStore_MultipleSessions(t *testing.T) {
	store := NewPendingStore(DefaultSessionTTL)
	ip1 := netip.MustParseAddr("100.64.0.1")
	ip2 := netip.MustParseAddr("100.64.0.2")
	s1, err := store.Add(ip1, "admin", ".", "user1", "nonce-a")
	if err != nil {
		t.Fatalf("Add s1 failed: %v", err)
	}
	s2, err := store.Add(ip2, "jdoe", "DOMAIN", "user2", "nonce-b")
	if err != nil {
		t.Fatalf("Add s2 failed: %v", err)
	}
	// Query each
	found1, ok := store.QueryByPeerIP(ip1)
	if !ok || found1.SessionID != s1.SessionID {
		t.Fatal("expected to find s1")
	}
	found2, ok := store.QueryByPeerIP(ip2)
	if !ok || found2.SessionID != s2.SessionID {
		t.Fatal("expected to find s2")
	}
	if found2.Domain != "DOMAIN" {
		t.Errorf("expected domain DOMAIN, got %s", found2.Domain)
	}
	if store.Count() != 2 {
		t.Errorf("expected count 2, got %d", store.Count())
	}
}
// TestGenerateNonce verifies the hex-encoded nonce length and that successive
// nonces differ (randomness smoke test).
func TestGenerateNonce(t *testing.T) {
	nonce1, err := GenerateNonce()
	if err != nil {
		t.Fatalf("GenerateNonce failed: %v", err)
	}
	nonce2, err := GenerateNonce()
	if err != nil {
		t.Fatalf("GenerateNonce failed: %v", err)
	}
	if len(nonce1) != nonceLength*2 { // hex encoding doubles the length
		t.Errorf("expected nonce length %d, got %d", nonceLength*2, len(nonce1))
	}
	if nonce1 == nonce2 {
		t.Error("expected unique nonces")
	}
}
// TestParseWindowsUsername covers the three accepted account formats:
// plain username, DOMAIN\username, and username@domain (UPN).
func TestParseWindowsUsername(t *testing.T) {
	tests := []struct {
		input          string
		expectedUser   string
		expectedDomain string
	}{
		{"admin", "admin", "."},
		{"DOMAIN\\admin", "admin", "DOMAIN"},
		{"admin@domain.com", "admin", "domain.com"},
		{".\\localuser", "localuser", "."},
	}
	for _, tt := range tests {
		user, domain := parseWindowsUsername(tt.input)
		if user != tt.expectedUser {
			t.Errorf("parseWindowsUsername(%q) user = %q, want %q", tt.input, user, tt.expectedUser)
		}
		if domain != tt.expectedDomain {
			t.Errorf("parseWindowsUsername(%q) domain = %q, want %q", tt.input, domain, tt.expectedDomain)
		}
	}
}

View File

@@ -0,0 +1,19 @@
//go:build !windows
package server
import "context"
// stubPipeServer is a no-op PipeServer used on platforms without named-pipe
// support; the credential provider IPC only exists on Windows.
type stubPipeServer struct{}

// newPipeServer returns the no-op pipe server; the pending store is unused here.
func newPipeServer(_ *PendingStore) PipeServer {
	return &stubPipeServer{}
}

// Start is a no-op on non-Windows platforms.
func (s *stubPipeServer) Start(_ context.Context) error {
	return nil
}

// Stop is a no-op on non-Windows platforms.
func (s *stubPipeServer) Stop() error {
	return nil
}

View File

@@ -0,0 +1,164 @@
//go:build windows
package server
import (
"context"
"encoding/json"
"io"
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/Microsoft/go-winio"
)
const (
	// PipeName is the named pipe path used for IPC between the NetBird agent and
	// the Credential Provider DLL.
	PipeName = `\\.\pipe\netbird-rdp-auth`
	// pipeSDDL restricts access to LOCAL_SYSTEM (SY) and Administrators (BA);
	// D:P marks the DACL protected so inherited ACEs cannot loosen it.
	pipeSDDL = "D:P(A;;GA;;;SY)(A;;GA;;;BA)"
	// maxPipeRequestSize is the maximum size of a pipe request in bytes.
	maxPipeRequestSize = 4096
)
// windowsPipeServer implements the PipeServer interface for Windows. It
// exposes the pending-session store to the Credential Provider DLL over a
// named pipe.
type windowsPipeServer struct {
	pending  *PendingStore // pending RDP sessions shared with the auth server
	listener net.Listener  // named pipe listener; nil when stopped
	mu       sync.Mutex    // guards listener, ctx, cancel
	ctx      context.Context
	cancel   context.CancelFunc
}

// newPipeServer creates a Windows named-pipe IPC server backed by the given
// pending-session store. The pipe is opened lazily in Start.
func newPipeServer(pending *PendingStore) PipeServer {
	return &windowsPipeServer{
		pending: pending,
	}
}
// Start creates the named pipe listener (access restricted by pipeSDDL to
// SYSTEM/Administrators) and begins accepting credential-provider connections.
func (s *windowsPipeServer) Start(ctx context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.ctx, s.cancel = context.WithCancel(ctx)
	listener, err := winio.ListenPipe(PipeName, &winio.PipeConfig{
		SecurityDescriptor: pipeSDDL,
	})
	if err != nil {
		return err
	}
	s.listener = listener

	go s.acceptLoop()
	log.Infof("RDP named pipe server started on %s", PipeName)
	return nil
}
// Stop cancels the accept loop and closes the pipe listener. It is safe to
// call when the server was never started.
func (s *windowsPipeServer) Stop() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.cancel != nil {
		s.cancel()
	}
	listener := s.listener
	s.listener = nil
	if listener == nil {
		return nil
	}
	return listener.Close()
}
// acceptLoop accepts pipe connections until shutdown, dispatching each to its
// own handler goroutine.
//
// The listener is re-read under the mutex on every iteration: Stop() closes it
// and sets s.listener to nil, and the previous version kept dereferencing the
// stale field unsynchronized, which could panic with a nil pointer after Stop.
func (s *windowsPipeServer) acceptLoop() {
	for {
		s.mu.Lock()
		listener := s.listener
		s.mu.Unlock()
		if listener == nil {
			return
		}
		conn, err := listener.Accept()
		if err != nil {
			if s.ctx.Err() != nil {
				return
			}
			log.Debugf("RDP pipe accept error: %v", err)
			continue
		}
		go s.handlePipeConnection(conn)
	}
}
// handlePipeConnection serves one credential-provider request: read a
// size-limited JSON PipeRequest, dispatch on its action, and write back a
// JSON PipeResponse. Malformed or unknown requests receive no response.
func (s *windowsPipeServer) handlePipeConnection(conn net.Conn) {
	defer func() {
		if err := conn.Close(); err != nil {
			log.Debugf("RDP pipe close: %v", err)
		}
	}()

	raw, err := io.ReadAll(io.LimitReader(conn, maxPipeRequestSize))
	if err != nil {
		log.Debugf("RDP pipe read: %v", err)
		return
	}
	var req PipeRequest
	if err := json.Unmarshal(raw, &req); err != nil {
		log.Debugf("RDP pipe unmarshal: %v", err)
		return
	}

	var resp PipeResponse
	switch req.Action {
	case PipeActionConsume:
		resp = s.handleConsume(req.SessionID)
	case PipeActionQuery:
		resp = s.handleQuery(req.RemoteIP)
	default:
		log.Debugf("RDP pipe unknown action: %s", req.Action)
		return
	}

	encoded, err := json.Marshal(resp)
	if err != nil {
		log.Debugf("RDP pipe marshal response: %v", err)
		return
	}
	if _, err := conn.Write(encoded); err != nil {
		log.Debugf("RDP pipe write response: %v", err)
	}
}
// handleQuery looks up a pending session for the connecting peer's WG IP and
// returns its identity details; the zero response (Found=false) means nothing
// is pending.
func (s *windowsPipeServer) handleQuery(remoteIP string) PipeResponse {
	peerIP, err := parseAddr(remoteIP)
	if err != nil {
		log.Debugf("RDP pipe invalid remote IP: %s", remoteIP)
		return PipeResponse{}
	}
	session, ok := s.pending.QueryByPeerIP(peerIP)
	if !ok {
		return PipeResponse{}
	}
	return PipeResponse{
		Found:     true,
		SessionID: session.SessionID,
		OSUser:    session.OSUsername,
		Domain:    session.Domain,
	}
}
// handleConsume marks the session consumed (single-use); Found reports success.
func (s *windowsPipeServer) handleConsume(sessionID string) PipeResponse {
	if !s.pending.Consume(sessionID) {
		return PipeResponse{}
	}
	return PipeResponse{Found: true, SessionID: sessionID}
}

View File

@@ -0,0 +1,48 @@
package server
// AuthRequest is the sideband authorization request sent by the connecting peer
// to the target peer's RDP auth server over the WireGuard tunnel.
type AuthRequest struct {
	JWTToken      string `json:"jwt_token"`      // caller's identity token
	RequestedUser string `json:"requested_user"` // Windows account to log on as
	ClientPeerIP  string `json:"client_peer_ip"` // requester's WG IP (informational)
	Nonce         string `json:"nonce"`          // single-use value for replay protection
}
// AuthResponse is the sideband authorization response sent by the target peer
// back to the connecting peer.
type AuthResponse struct {
	Status    string `json:"status"`               // "authorized" or "denied"
	SessionID string `json:"session_id,omitempty"` // set when authorized
	ExpiresAt int64  `json:"expires_at,omitempty"` // unix timestamp
	OSUser    string `json:"os_user,omitempty"`    // resolved Windows username
	Reason    string `json:"reason,omitempty"`     // human-readable denial reason
}
// PipeRequest is the IPC request from the Credential Provider DLL to the NetBird agent
// via the named pipe.
type PipeRequest struct {
	Action    string `json:"action"`               // "query_pending" or "consume"
	RemoteIP  string `json:"remote_ip"`            // connecting peer's WG IP
	SessionID string `json:"session_id,omitempty"` // for consume action
}
// PipeResponse is the IPC response from the NetBird agent to the Credential Provider DLL.
// The zero value (Found=false) means no matching pending session.
type PipeResponse struct {
	Found     bool   `json:"found"`
	SessionID string `json:"session_id,omitempty"`
	OSUser    string `json:"os_user,omitempty"`
	Domain    string `json:"domain,omitempty"` // "." for local accounts
}
const (
	// StatusAuthorized indicates the RDP session was authorized
	// (AuthResponse.Status).
	StatusAuthorized = "authorized"
	// StatusDenied indicates the RDP session was denied.
	StatusDenied = "denied"
	// PipeActionQuery queries for a pending session by remote IP
	// (PipeRequest.Action).
	PipeActionQuery = "query_pending"
	// PipeActionConsume marks a pending session as consumed.
	PipeActionConsume = "consume"
)

301
client/rdp/server/server.go Normal file
View File

@@ -0,0 +1,301 @@
package server
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/netip"
	"strings"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"

	sshauth "github.com/netbirdio/netbird/client/ssh/auth"
)
const (
	// InternalRDPAuthPort is the port the sideband auth server listens on.
	InternalRDPAuthPort = 22338
	// DefaultRDPAuthPort is the external port on the WireGuard interface
	// (DNAT target); currently identical to the internal port.
	DefaultRDPAuthPort = 22338
	// maxRequestSize is the maximum size of an auth request in bytes.
	maxRequestSize = 64 * 1024
	// connectionTimeout is the timeout for a single auth connection.
	connectionTimeout = 30 * time.Second
)
// JWTValidator validates JWT tokens and extracts user identity.
type JWTValidator interface {
	// ValidateAndExtract returns the authenticated user ID for a valid token,
	// or an error when validation fails.
	ValidateAndExtract(token string) (userID string, err error)
}

// Authorizer checks if a user is authorized for RDP access.
type Authorizer interface {
	// Authorize checks whether jwtUserID may log on as osUsername; a non-nil
	// error denies access.
	Authorize(jwtUserID, osUsername string) (string, error)
}
// Server is the sideband RDP authorization server that listens on the WireGuard interface.
// mu guards listener, ctx and cancel; the struct must not be copied.
type Server struct {
	listener      net.Listener
	pending       *PendingStore     // authorized-but-unconsumed sessions
	pipeServer    PipeServer        // IPC to the credential provider DLL
	jwtValidator  JWTValidator      // nil disables JWT validation (POC mode)
	authorizer    Authorizer        // explicit authorizer; takes precedence
	sshAuthorizer *sshauth.Authorizer // reuses SSH ACL for RDP access control
	networkAddr   netip.Prefix // WireGuard network for source IP validation
	mu            sync.Mutex
	ctx           context.Context
	cancel        context.CancelFunc
}
// PipeServer is the interface for the named pipe IPC server (platform-specific).
type PipeServer interface {
	Start(ctx context.Context) error
	Stop() error
}

// Config holds the configuration for the RDP auth server.
type Config struct {
	JWTValidator JWTValidator // nil accepts all WG peers (POC)
	Authorizer   Authorizer   // nil falls back to the SSH ACL
	NetworkAddr  netip.Prefix // WG network used to filter sources
	SessionTTL   time.Duration // non-positive uses DefaultSessionTTL
}
// New creates a new RDP sideband auth server. A non-positive SessionTTL in
// cfg falls back to DefaultSessionTTL.
func New(cfg *Config) *Server {
	sessionTTL := cfg.SessionTTL
	if sessionTTL <= 0 {
		sessionTTL = DefaultSessionTTL
	}
	store := NewPendingStore(sessionTTL)
	return &Server{
		pending:       store,
		pipeServer:    newPipeServer(store),
		jwtValidator:  cfg.JWTValidator,
		authorizer:    cfg.Authorizer,
		sshAuthorizer: sshauth.NewAuthorizer(),
		networkAddr:   cfg.NetworkAddr,
	}
}
// UpdateRDPAuth updates the RDP authorization config (reuses the SSH ACL).
func (s *Server) UpdateRDPAuth(cfg *sshauth.Config) {
	s.sshAuthorizer.Update(cfg)
	log.Debugf("RDP auth: updated authorization config")
}
// Start begins listening for sideband auth requests on the given address. It
// also starts the pending-store cleanup loop and the named-pipe IPC server;
// a pipe failure is logged but does not abort startup.
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.listener != nil {
		return errors.New("RDP auth server already running")
	}
	s.ctx, s.cancel = context.WithCancel(ctx)

	listener, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(addr))
	if err != nil {
		return fmt.Errorf("listen on %s: %w", addr, err)
	}
	s.listener = listener

	s.pending.StartCleanup(s.ctx)
	if s.pipeServer != nil {
		if err := s.pipeServer.Start(s.ctx); err != nil {
			log.Warnf("failed to start RDP named pipe server: %v", err)
		}
	}

	go s.acceptLoop()
	log.Infof("RDP sideband auth server started on %s", addr)
	return nil
}
// Stop shuts down the pipe server and the TCP listener and cancels background
// work. It is safe to call when the server was never started.
func (s *Server) Stop() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.cancel != nil {
		s.cancel()
	}
	if s.pipeServer != nil {
		if err := s.pipeServer.Stop(); err != nil {
			log.Warnf("failed to stop RDP named pipe server: %v", err)
		}
	}
	if listener := s.listener; listener != nil {
		s.listener = nil
		if err := listener.Close(); err != nil {
			return fmt.Errorf("close listener: %w", err)
		}
	}
	log.Info("RDP sideband auth server stopped")
	return nil
}
// GetPendingStore returns the pending session store (for testing/named pipe access).
func (s *Server) GetPendingStore() *PendingStore {
	return s.pending
}
// acceptLoop accepts sideband auth connections until the listener is closed,
// dispatching each connection to its own handler goroutine.
//
// The listener is re-read under the mutex on every iteration: Stop() closes it
// and sets s.listener to nil, and the previous version kept dereferencing the
// stale field unsynchronized, which could panic with a nil pointer after Stop.
func (s *Server) acceptLoop() {
	for {
		s.mu.Lock()
		listener := s.listener
		s.mu.Unlock()
		if listener == nil {
			return
		}
		conn, err := listener.Accept()
		if err != nil {
			if s.ctx.Err() != nil {
				return
			}
			log.Debugf("RDP auth accept error: %v", err)
			continue
		}
		go s.handleConnection(conn)
	}
}
// handleConnection serves a single sideband auth request: enforce a deadline,
// verify the source address is inside the WireGuard network, decode the JSON
// request, and reply with an AuthResponse.
func (s *Server) handleConnection(conn net.Conn) {
	defer func() {
		if err := conn.Close(); err != nil {
			log.Debugf("RDP auth close connection: %v", err)
		}
	}()

	if err := conn.SetDeadline(time.Now().Add(connectionTimeout)); err != nil {
		log.Debugf("RDP auth set deadline: %v", err)
		return
	}

	// Only peers on the WireGuard network may request authorization.
	remoteAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
	if err != nil {
		log.Debugf("RDP auth parse remote addr: %v", err)
		return
	}
	peer := remoteAddr.Addr()
	if !s.networkAddr.Contains(peer) {
		log.Warnf("RDP auth rejected connection from non-WG address: %s", peer)
		return
	}

	payload, err := io.ReadAll(io.LimitReader(conn, maxRequestSize))
	if err != nil {
		log.Debugf("RDP auth read request: %v", err)
		return
	}
	var req AuthRequest
	if err := json.Unmarshal(payload, &req); err != nil {
		log.Debugf("RDP auth unmarshal request: %v", err)
		s.sendResponse(conn, &AuthResponse{Status: StatusDenied, Reason: "invalid request format"})
		return
	}

	s.sendResponse(conn, s.processAuthRequest(peer, &req))
}
// processAuthRequest validates the JWT (when a validator is configured) and
// checks authorization — the explicit authorizer first, otherwise the SSH
// ACL — before minting a pending session for the peer.
func (s *Server) processAuthRequest(peerIP netip.Addr, req *AuthRequest) *AuthResponse {
	if s.jwtValidator == nil {
		// No JWT validation configured - for POC, accept all requests from WG peers
		log.Warnf("RDP auth: no JWT validator configured, accepting request from %s", peerIP)
		return s.createSession(peerIP, req, "no-jwt-validation")
	}

	userID, err := s.jwtValidator.ValidateAndExtract(req.JWTToken)
	if err != nil {
		log.Warnf("RDP auth JWT validation failed for %s: %v", peerIP, err)
		return &AuthResponse{Status: StatusDenied, Reason: "JWT validation failed"}
	}

	switch {
	case s.authorizer != nil:
		if _, err := s.authorizer.Authorize(userID, req.RequestedUser); err != nil {
			log.Warnf("RDP auth denied for user %s -> %s: %v", userID, req.RequestedUser, err)
			return &AuthResponse{Status: StatusDenied, Reason: "not authorized for this user"}
		}
	case s.sshAuthorizer != nil:
		if _, err := s.sshAuthorizer.Authorize(userID, req.RequestedUser); err != nil {
			log.Warnf("RDP auth denied (SSH ACL) for user %s -> %s: %v", userID, req.RequestedUser, err)
			return &AuthResponse{Status: StatusDenied, Reason: "not authorized for this user"}
		}
	}

	return s.createSession(peerIP, req, userID)
}
// createSession parses the requested Windows account into user/domain and
// records a pending session; a rejected nonce (replay) yields a denial.
func (s *Server) createSession(peerIP netip.Addr, req *AuthRequest, jwtUserID string) *AuthResponse {
	osUser, domain := parseWindowsUsername(req.RequestedUser)
	session, err := s.pending.Add(peerIP, osUser, domain, jwtUserID, req.Nonce)
	if err != nil {
		log.Warnf("RDP auth create session failed: %v", err)
		return &AuthResponse{Status: StatusDenied, Reason: err.Error()}
	}
	return &AuthResponse{
		Status:    StatusAuthorized,
		SessionID: session.SessionID,
		ExpiresAt: session.ExpiresAt.Unix(),
		OSUser:    session.OSUsername,
	}
}
// sendResponse serializes resp as JSON and writes it to conn. Failures are
// logged at debug level and otherwise ignored; the connection is closed by
// the caller.
func (s *Server) sendResponse(conn net.Conn, resp *AuthResponse) {
	payload, err := json.Marshal(resp)
	if err != nil {
		log.Debugf("RDP auth marshal response: %v", err)
		return
	}

	if _, writeErr := conn.Write(payload); writeErr != nil {
		log.Debugf("RDP auth write response: %v", writeErr)
	}
}
// parseWindowsUsername extracts username and domain from Windows username formats.
// Supports DOMAIN\username (split at the last backslash), username@domain (split
// at the first '@'), and plain username. The backslash form takes precedence.
// A plain username maps to domain "." (the local machine).
func parseWindowsUsername(fullUsername string) (username, domain string) {
	// DOMAIN\username: scan backwards so nested separators favor the last one.
	for i := len(fullUsername) - 1; i >= 0; i-- {
		if fullUsername[i] == '\\' {
			return fullUsername[i+1:], fullUsername[:i]
		}
	}
	// username@domain: first '@' wins.
	for i := 0; i < len(fullUsername); i++ {
		if fullUsername[i] == '@' {
			return fullUsername[:i], fullUsername[i+1:]
		}
	}
	return fullUsername, "."
}
// indexOf returns the index of the first occurrence of byte c in s, or -1 if
// absent. Local equivalent of strings.IndexByte, kept here to avoid pulling in
// an extra import for a single call site.
func indexOf(s string, c byte) int {
	// range over []byte(s) is a compiler-recognized pattern and does not copy.
	for i, b := range []byte(s) {
		if b == c {
			return i
		}
	}
	return -1
}

View File

@@ -366,6 +366,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
config.ServerSSHAllowed = msg.ServerSSHAllowed
config.ServerRDPAllowed = msg.ServerRDPAllowed
config.NetworkMonitor = msg.NetworkMonitor
config.DisableClientRoutes = msg.DisableClientRoutes
config.DisableServerRoutes = msg.DisableServerRoutes
@@ -1514,6 +1515,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
ServerRDPAllowed: cfg.ServerRDPAllowed != nil && *cfg.ServerRDPAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,

View File

@@ -58,6 +58,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
serverRDPAllowed := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
@@ -82,6 +83,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
ServerRDPAllowed: &serverRDPAllowed,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
@@ -125,6 +127,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.NotNil(t, cfg.ServerRDPAllowed)
require.Equal(t, serverRDPAllowed, *cfg.ServerRDPAllowed)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
@@ -176,6 +180,7 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"ServerRDPAllowed": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
@@ -236,6 +241,7 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"allow-server-rdp": "ServerRDPAllowed",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",

View File

@@ -179,9 +179,11 @@ type StoreConfig struct {
// ReverseProxyConfig contains reverse proxy settings
type ReverseProxyConfig struct {
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
TrustedPeers []string `yaml:"trustedPeers"`
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
TrustedPeers []string `yaml:"trustedPeers"`
AccessLogRetentionDays int `yaml:"accessLogRetentionDays"`
AccessLogCleanupIntervalHours int `yaml:"accessLogCleanupIntervalHours"`
}
// DefaultConfig returns a CombinedConfig with default values
@@ -645,7 +647,9 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
// Build reverse proxy config
reverseProxy := nbconfig.ReverseProxy{
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
AccessLogRetentionDays: mgmt.ReverseProxy.AccessLogRetentionDays,
AccessLogCleanupIntervalHours: mgmt.ReverseProxy.AccessLogCleanupIntervalHours,
}
for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies {
if prefix, err := netip.ParsePrefix(p); err == nil {

View File

@@ -302,7 +302,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "rate(management_account_peer_meta_update_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])",
"expr": "rate(management_account_peer_meta_update_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])",
"instant": false,
"legendFormat": "{{cluster}}/{{environment}}/{{job}}",
"range": true,
@@ -410,7 +410,7 @@
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "histogram_quantile(0.5,sum(increase(management_account_get_peer_network_map_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"expr": "histogram_quantile(0.5,sum(increase(management_account_get_peer_network_map_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"format": "heatmap",
"fullMetaSearch": false,
"includeNullMetadata": true,
@@ -426,7 +426,7 @@
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "histogram_quantile(0.9,sum(increase(management_account_get_peer_network_map_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"expr": "histogram_quantile(0.9,sum(increase(management_account_get_peer_network_map_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"format": "heatmap",
"fullMetaSearch": false,
"hide": false,
@@ -443,7 +443,7 @@
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "histogram_quantile(0.99,sum(increase(management_account_get_peer_network_map_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"expr": "histogram_quantile(0.99,sum(increase(management_account_get_peer_network_map_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"format": "heatmap",
"fullMetaSearch": false,
"hide": false,
@@ -545,7 +545,7 @@
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "histogram_quantile(0.5,sum(increase(management_account_update_account_peers_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"expr": "histogram_quantile(0.5,sum(increase(management_account_update_account_peers_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"format": "heatmap",
"fullMetaSearch": false,
"includeNullMetadata": true,
@@ -561,7 +561,7 @@
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "histogram_quantile(0.9,sum(increase(management_account_update_account_peers_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"expr": "histogram_quantile(0.9,sum(increase(management_account_update_account_peers_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"format": "heatmap",
"fullMetaSearch": false,
"hide": false,
@@ -578,7 +578,7 @@
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "histogram_quantile(0.99,sum(increase(management_account_update_account_peers_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"expr": "histogram_quantile(0.99,sum(increase(management_account_update_account_peers_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"format": "heatmap",
"fullMetaSearch": false,
"hide": false,
@@ -694,7 +694,7 @@
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "histogram_quantile(0.5,sum(increase(management_grpc_updatechannel_queue_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"expr": "histogram_quantile(0.5,sum(increase(management_grpc_updatechannel_queue_length_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"format": "heatmap",
"fullMetaSearch": false,
"includeNullMetadata": true,
@@ -710,7 +710,7 @@
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "histogram_quantile(0.9,sum(increase(management_grpc_updatechannel_queue_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"expr": "histogram_quantile(0.9,sum(increase(management_grpc_updatechannel_queue_length_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"format": "heatmap",
"fullMetaSearch": false,
"hide": false,
@@ -727,7 +727,7 @@
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "histogram_quantile(0.99,sum(increase(management_grpc_updatechannel_queue_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"expr": "histogram_quantile(0.99,sum(increase(management_grpc_updatechannel_queue_length_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le,cluster,environment,job))",
"format": "heatmap",
"fullMetaSearch": false,
"hide": false,
@@ -841,7 +841,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.50, sum(rate(management_store_persistence_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"expr": "histogram_quantile(0.50, sum(rate(management_store_persistence_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"instant": false,
"legendFormat": "p50",
"range": true,
@@ -853,7 +853,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.90, sum(rate(management_store_persistence_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"expr": "histogram_quantile(0.90, sum(rate(management_store_persistence_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"hide": false,
"instant": false,
"legendFormat": "p90",
@@ -866,7 +866,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum(rate(management_store_persistence_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"expr": "histogram_quantile(0.99, sum(rate(management_store_persistence_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"hide": false,
"instant": false,
"legendFormat": "p99",
@@ -963,7 +963,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.50, sum(rate(management_store_transaction_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"expr": "histogram_quantile(0.50, sum(rate(management_store_transaction_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"instant": false,
"legendFormat": "p50",
"range": true,
@@ -975,7 +975,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.90, sum(rate(management_store_transaction_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"expr": "histogram_quantile(0.90, sum(rate(management_store_transaction_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"hide": false,
"instant": false,
"legendFormat": "p90",
@@ -988,7 +988,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum(rate(management_store_transaction_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"expr": "histogram_quantile(0.99, sum(rate(management_store_transaction_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"hide": false,
"instant": false,
"legendFormat": "p99",
@@ -1085,7 +1085,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.50, sum(rate(management_store_global_lock_acquisition_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"expr": "histogram_quantile(0.50, sum(rate(management_store_global_lock_acquisition_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"instant": false,
"legendFormat": "p50",
"range": true,
@@ -1097,7 +1097,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.90, sum(rate(management_store_global_lock_acquisition_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"expr": "histogram_quantile(0.90, sum(rate(management_store_global_lock_acquisition_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"hide": false,
"instant": false,
"legendFormat": "p90",
@@ -1110,7 +1110,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum(rate(management_store_global_lock_acquisition_duration_ms_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"expr": "histogram_quantile(0.99, sum(rate(management_store_global_lock_acquisition_duration_ms_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (le))",
"hide": false,
"instant": false,
"legendFormat": "p99",
@@ -1221,7 +1221,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "rate(management_idp_authenticate_request_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])",
"expr": "rate(management_idp_authenticate_request_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])",
"instant": false,
"legendFormat": "{{cluster}}/{{environment}}/{{job}}",
"range": true,
@@ -1317,7 +1317,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "rate(management_idp_get_account_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])",
"expr": "rate(management_idp_get_account_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])",
"instant": false,
"legendFormat": "{{cluster}}/{{environment}}/{{job}}",
"range": true,
@@ -1413,7 +1413,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "rate(management_idp_update_user_meta_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])",
"expr": "rate(management_idp_update_user_meta_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])",
"instant": false,
"legendFormat": "{{cluster}}/{{environment}}/{{job}}",
"range": true,
@@ -1523,7 +1523,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "sum(rate(management_http_request_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",method=~\"GET|OPTIONS\"}[$__rate_interval])) by (job,method)",
"expr": "sum(rate(management_http_request_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",method=~\"GET|OPTIONS\"}[$__rate_interval])) by (job,method)",
"instant": false,
"legendFormat": "{{method}}",
"range": true,
@@ -1619,7 +1619,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "sum(rate(management_http_request_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",method=~\"POST|PUT|DELETE\"}[$__rate_interval])) by (job,method)",
"expr": "sum(rate(management_http_request_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",method=~\"POST|PUT|DELETE\"}[$__rate_interval])) by (job,method)",
"instant": false,
"legendFormat": "{{method}}",
"range": true,
@@ -1715,7 +1715,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.50, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))",
"expr": "histogram_quantile(0.50, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))",
"instant": false,
"legendFormat": "p50",
"range": true,
@@ -1727,7 +1727,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.90, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))",
"expr": "histogram_quantile(0.90, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))",
"hide": false,
"instant": false,
"legendFormat": "p90",
@@ -1740,7 +1740,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))",
"expr": "histogram_quantile(0.99, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"read\"}[5m])) by (le))",
"hide": false,
"instant": false,
"legendFormat": "p99",
@@ -1837,7 +1837,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.50, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))",
"expr": "histogram_quantile(0.50, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))",
"instant": false,
"legendFormat": "p50",
"range": true,
@@ -1849,7 +1849,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.90, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))",
"expr": "histogram_quantile(0.90, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))",
"hide": false,
"instant": false,
"legendFormat": "p90",
@@ -1862,7 +1862,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum(rate(management_http_request_duration_ms_total_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))",
"expr": "histogram_quantile(0.99, sum(rate(management_http_request_duration_ms_total_milliseconds_bucket{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\",type=~\"write\"}[5m])) by (le))",
"hide": false,
"instant": false,
"legendFormat": "p99",
@@ -1963,7 +1963,7 @@
"uid": "${datasource}"
},
"editorMode": "code",
"expr": "sum(rate(management_http_request_counter_ratio_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (job,exported_endpoint,method)",
"expr": "sum(rate(management_http_request_counter_total{cluster=~\"$cluster\",environment=~\"$environment\",job=~\"$job\",host=~\"$host\"}[$__rate_interval])) by (job,exported_endpoint,method)",
"hide": false,
"instant": false,
"legendFormat": "{{method}}-{{exported_endpoint}}",
@@ -3222,7 +3222,7 @@
},
"disableTextWrap": false,
"editorMode": "builder",
"expr": "sum by(le) (increase(management_grpc_updatechannel_queue_bucket{application=\"management\", environment=\"$environment\", host=~\"$host\"}[$__rate_interval]))",
"expr": "sum by(le) (increase(management_grpc_updatechannel_queue_length_bucket{application=\"management\", environment=\"$environment\", host=~\"$host\"}[$__rate_interval]))",
"format": "heatmap",
"fullMetaSearch": false,
"includeNullMetadata": true,
@@ -3323,7 +3323,7 @@
},
"disableTextWrap": false,
"editorMode": "builder",
"expr": "sum by(le) (increase(management_account_update_account_peers_duration_ms_bucket{application=\"management\", environment=\"$environment\", host=~\"$host\"}[$__rate_interval]))",
"expr": "sum by(le) (increase(management_account_update_account_peers_duration_ms_milliseconds_bucket{application=\"management\", environment=\"$environment\", host=~\"$host\"}[$__rate_interval]))",
"format": "heatmap",
"fullMetaSearch": false,
"includeNullMetadata": true,

View File

@@ -106,13 +106,23 @@ func (m *managerImpl) CleanupOldAccessLogs(ctx context.Context, retentionDays in
// StartPeriodicCleanup starts a background goroutine that periodically cleans up old access logs
func (m *managerImpl) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
if retentionDays <= 0 {
log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is 0 or negative")
if retentionDays < 0 {
log.WithContext(ctx).Debug("periodic access log cleanup disabled: retention days is negative")
return
}
if retentionDays == 0 {
retentionDays = 7
log.WithContext(ctx).Debugf("no retention days specified for access log cleanup, defaulting to %d days", retentionDays)
} else {
log.WithContext(ctx).Debugf("access log retention period set to %d days", retentionDays)
}
if cleanupIntervalHours <= 0 {
cleanupIntervalHours = 24
log.WithContext(ctx).Debugf("no cleanup interval specified for access log cleanup, defaulting to %d hours", cleanupIntervalHours)
} else {
log.WithContext(ctx).Debugf("access log cleanup interval set to %d hours", cleanupIntervalHours)
}
cleanupCtx, cancel := context.WithCancel(ctx)

View File

@@ -121,7 +121,7 @@ func TestCleanupWithExactBoundary(t *testing.T) {
}
func TestStartPeriodicCleanup(t *testing.T) {
t.Run("periodic cleanup disabled with zero retention", func(t *testing.T) {
t.Run("periodic cleanup disabled with negative retention", func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
@@ -135,7 +135,7 @@ func TestStartPeriodicCleanup(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
manager.StartPeriodicCleanup(ctx, 0, 1)
manager.StartPeriodicCleanup(ctx, -1, 1)
time.Sleep(100 * time.Millisecond)

View File

@@ -30,3 +30,8 @@ func (d *Domain) EventMeta() map[string]any {
"validated": d.Validated,
}
}
func (d *Domain) Copy() *Domain {
dCopy := *d
return &dCopy
}

View File

@@ -203,7 +203,7 @@ type ReverseProxy struct {
// AccessLogRetentionDays specifies the number of days to retain access logs.
// Logs older than this duration will be automatically deleted during cleanup.
// A value of 0 or negative means logs are kept indefinitely (no cleanup).
// A value of 0 will default to 7 days. Negative means logs are kept indefinitely (no cleanup).
AccessLogRetentionDays int
// AccessLogCleanupIntervalHours specifies how often (in hours) to run the cleanup routine.

View File

@@ -742,11 +742,6 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
}
err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
}
for _, otherUser := range account.Users {
if otherUser.Id == userID {
continue

View File

@@ -15,7 +15,6 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/prometheus/client_golang/prometheus/push"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -23,6 +22,9 @@ import (
"go.opentelemetry.io/otel/metric/noop"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/shared/management/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -1815,6 +1817,13 @@ func TestAccount_Copy(t *testing.T) {
Targets: []*service.Target{},
},
},
Domains: []*domain.Domain{
{
ID: "domain1",
Domain: "test.com",
AccountID: "account1",
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
}
account.InitOnce()

View File

@@ -489,6 +489,102 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
return nil
}
// hasForeignKey checks whether a foreign key constraint exists on the given table and column.
func hasForeignKey(db *gorm.DB, table, column string) bool {
var count int64
switch db.Name() {
case "postgres":
db.Raw(`
SELECT COUNT(*) FROM information_schema.key_column_usage kcu
JOIN information_schema.table_constraints tc
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
AND kcu.table_name = ?
AND kcu.column_name = ?
`, table, column).Scan(&count)
case "mysql":
db.Raw(`
SELECT COUNT(*) FROM information_schema.key_column_usage
WHERE table_schema = DATABASE()
AND table_name = ?
AND column_name = ?
AND referenced_table_name IS NOT NULL
`, table, column).Scan(&count)
default: // sqlite
type fkInfo struct {
From string
}
var fks []fkInfo
db.Raw(fmt.Sprintf("PRAGMA foreign_key_list(%s)", table)).Scan(&fks)
for _, fk := range fks {
if fk.From == column {
return true
}
}
return false
}
return count > 0
}
// CleanupOrphanedResources deletes rows from the table of model T where the foreign
// key column (fkColumn) references a row in the table of model R that no longer exists.
func CleanupOrphanedResources[T any, R any](ctx context.Context, db *gorm.DB, fkColumn string) error {
var model T
var refModel R
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no cleanup needed", model)
return nil
}
if !db.Migrator().HasTable(&refModel) {
log.WithContext(ctx).Debugf("referenced table for %T does not exist, no cleanup needed", refModel)
return nil
}
stmtT := &gorm.Statement{DB: db}
if err := stmtT.Parse(&model); err != nil {
return fmt.Errorf("parse model %T: %w", model, err)
}
childTable := stmtT.Schema.Table
stmtR := &gorm.Statement{DB: db}
if err := stmtR.Parse(&refModel); err != nil {
return fmt.Errorf("parse reference model %T: %w", refModel, err)
}
parentTable := stmtR.Schema.Table
if !db.Migrator().HasColumn(&model, fkColumn) {
log.WithContext(ctx).Debugf("column %s does not exist in table %s, no cleanup needed", fkColumn, childTable)
return nil
}
// If a foreign key constraint already exists on the column, the DB itself
// enforces referential integrity and orphaned rows cannot exist.
if hasForeignKey(db, childTable, fkColumn) {
log.WithContext(ctx).Debugf("foreign key constraint for %s already exists on %s, no cleanup needed", fkColumn, childTable)
return nil
}
result := db.Exec(
fmt.Sprintf(
"DELETE FROM %s WHERE %s NOT IN (SELECT id FROM %s)",
childTable, fkColumn, parentTable,
),
)
if result.Error != nil {
return fmt.Errorf("cleanup orphaned rows in %s: %w", childTable, result.Error)
}
log.WithContext(ctx).Infof("Cleaned up %d orphaned rows from %s where %s had no matching row in %s",
result.RowsAffected, childTable, fkColumn, parentTable)
return nil
}
func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error {
if !db.Migrator().HasTable("peers") {
log.WithContext(ctx).Debug("peers table does not exist, skipping duplicate key cleanup")

View File

@@ -441,3 +441,197 @@ func TestRemoveDuplicatePeerKeys_NoTable(t *testing.T) {
err := migration.RemoveDuplicatePeerKeys(context.Background(), db)
require.NoError(t, err, "Should not fail when table does not exist")
}
type testParent struct {
ID string `gorm:"primaryKey"`
}
func (testParent) TableName() string {
return "test_parents"
}
type testChild struct {
ID string `gorm:"primaryKey"`
ParentID string
}
func (testChild) TableName() string {
return "test_children"
}
type testChildWithFK struct {
ID string `gorm:"primaryKey"`
ParentID string `gorm:"index"`
Parent *testParent `gorm:"foreignKey:ParentID"`
}
func (testChildWithFK) TableName() string {
return "test_children"
}
func setupOrphanTestDB(t *testing.T, models ...any) *gorm.DB {
t.Helper()
db := setupDatabase(t)
for _, m := range models {
_ = db.Migrator().DropTable(m)
}
err := db.AutoMigrate(models...)
require.NoError(t, err, "Failed to auto-migrate tables")
return db
}
func TestCleanupOrphanedResources_NoChildTable(t *testing.T) {
db := setupDatabase(t)
_ = db.Migrator().DropTable(&testChild{})
_ = db.Migrator().DropTable(&testParent{})
err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
require.NoError(t, err, "Should not fail when child table does not exist")
}
// TestCleanupOrphanedResources_NoParentTable verifies the migration is a
// harmless no-op when only the child table exists.
func TestCleanupOrphanedResources_NoParentTable(t *testing.T) {
	db := setupDatabase(t)

	// Start clean, then create only the child table.
	for _, model := range []any{&testParent{}, &testChild{}} {
		_ = db.Migrator().DropTable(model)
	}
	require.NoError(t, db.AutoMigrate(&testChild{}))

	cleanupErr := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
	require.NoError(t, cleanupErr, "Should not fail when parent table does not exist")
}
// TestCleanupOrphanedResources_EmptyTables verifies the migration succeeds
// when both tables exist but hold no rows.
func TestCleanupOrphanedResources_EmptyTables(t *testing.T) {
	db := setupOrphanTestDB(t, &testParent{}, &testChild{})

	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
	require.NoError(t, err, "Should not fail on empty tables")

	var count int64
	// Check the Count error: if the query failed, count would stay 0 and the
	// assertion below would pass vacuously.
	require.NoError(t, db.Model(&testChild{}).Count(&count).Error)
	assert.Equal(t, int64(0), count)
}
// TestCleanupOrphanedResources_NoOrphans verifies that children whose parents
// exist are never deleted.
func TestCleanupOrphanedResources_NoOrphans(t *testing.T) {
	db := setupOrphanTestDB(t, &testParent{}, &testChild{})

	for _, parent := range []*testParent{{ID: "p1"}, {ID: "p2"}} {
		require.NoError(t, db.Create(parent).Error)
	}
	for _, child := range []*testChild{
		{ID: "c1", ParentID: "p1"},
		{ID: "c2", ParentID: "p2"},
	} {
		require.NoError(t, db.Create(child).Error)
	}

	cleanupErr := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
	require.NoError(t, cleanupErr)

	var count int64
	db.Model(&testChild{}).Count(&count)
	assert.Equal(t, int64(2), count, "All children should remain when no orphans")
}
// TestCleanupOrphanedResources_AllOrphans verifies that every child is
// removed when no parent rows exist at all.
func TestCleanupOrphanedResources_AllOrphans(t *testing.T) {
	db := setupOrphanTestDB(t, &testParent{}, &testChild{})

	// Raw SQL inserts so the referenced parents are never created.
	for _, row := range [][2]string{{"c1", "gone1"}, {"c2", "gone2"}, {"c3", "gone3"}} {
		require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", row[0], row[1]).Error)
	}

	cleanupErr := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
	require.NoError(t, cleanupErr)

	var count int64
	db.Model(&testChild{}).Count(&count)
	assert.Equal(t, int64(0), count, "All orphaned children should be deleted")
}
// TestCleanupOrphanedResources_MixedValidAndOrphaned verifies that only
// orphans are deleted while children with live parents survive.
func TestCleanupOrphanedResources_MixedValidAndOrphaned(t *testing.T) {
	db := setupOrphanTestDB(t, &testParent{}, &testChild{})

	for _, parent := range []*testParent{{ID: "p1"}, {ID: "p2"}} {
		require.NoError(t, db.Create(parent).Error)
	}
	for _, child := range []*testChild{
		{ID: "c1", ParentID: "p1"},
		{ID: "c2", ParentID: "p2"},
		{ID: "c3", ParentID: "p1"},
	} {
		require.NoError(t, db.Create(child).Error)
	}
	// Orphans inserted via raw SQL so the referenced parents never exist.
	for _, row := range [][2]string{{"c4", "gone1"}, {"c5", "gone2"}} {
		require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", row[0], row[1]).Error)
	}

	cleanupErr := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
	require.NoError(t, cleanupErr)

	var remaining []testChild
	require.NoError(t, db.Order("id").Find(&remaining).Error)
	assert.Len(t, remaining, 3, "Only valid children should remain")
	for i, wantID := range []string{"c1", "c2", "c3"} {
		assert.Equal(t, wantID, remaining[i].ID)
	}
}
// TestCleanupOrphanedResources_Idempotent verifies that a second cleanup run
// neither fails nor deletes any additional rows.
func TestCleanupOrphanedResources_Idempotent(t *testing.T) {
	db := setupOrphanTestDB(t, &testParent{}, &testChild{})

	require.NoError(t, db.Create(&testParent{ID: "p1"}).Error)
	require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error)
	// Orphan inserted via raw SQL so the referenced parent never exists.
	require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c2", "gone").Error)

	ctx := context.Background()

	err := migration.CleanupOrphanedResources[testChild, testParent](ctx, db, "parent_id")
	require.NoError(t, err)

	var count int64
	// Check the Count errors: previously they were ignored, so a failed query
	// would leave count at 0 and silently fail (or vacuously pass) the
	// assertions below.
	require.NoError(t, db.Model(&testChild{}).Count(&count).Error)
	assert.Equal(t, int64(1), count)

	err = migration.CleanupOrphanedResources[testChild, testParent](ctx, db, "parent_id")
	require.NoError(t, err)

	require.NoError(t, db.Model(&testChild{}).Count(&count).Error)
	assert.Equal(t, int64(1), count, "Count should remain the same after second run")
}
// TestCleanupOrphanedResources_SkipsWhenForeignKeyExists verifies the
// migration's early-exit path: when the database already enforces a foreign
// key from the child column to the parent table, the migration must delete
// nothing — even though an orphaned row is present.
//
// Only postgres and mysql are exercised: the test needs to drop and re-create
// a real, named FK constraint, which requires one of those engines.
func TestCleanupOrphanedResources_SkipsWhenForeignKeyExists(t *testing.T) {
	engine := os.Getenv("NETBIRD_STORE_ENGINE")
	if engine != "postgres" && engine != "mysql" {
		t.Skip("FK constraint early-exit test requires postgres or mysql")
	}

	db := setupDatabase(t)
	_ = db.Migrator().DropTable(&testChildWithFK{})
	_ = db.Migrator().DropTable(&testParent{})
	// testChildWithFK declares a gorm foreign key, so AutoMigrate creates the
	// constraint fk_test_children_parent on the shared test_children table.
	err := db.AutoMigrate(&testParent{}, &testChildWithFK{})
	require.NoError(t, err)

	require.NoError(t, db.Create(&testParent{ID: "p1"}).Error)
	require.NoError(t, db.Create(&testParent{ID: "p2"}).Error)
	require.NoError(t, db.Create(&testChildWithFK{ID: "c1", ParentID: "p1"}).Error)
	require.NoError(t, db.Create(&testChildWithFK{ID: "c2", ParentID: "p2"}).Error)

	// Manufacture an orphan (c2 -> p2 with p2 deleted) while still ending up
	// with the FK constraint in place. Statement order matters: the FK must
	// be gone while the parent row is deleted, then restored afterwards.
	switch engine {
	case "postgres":
		// Re-add the constraint with NOT VALID so existing rows are not
		// validated — otherwise the orphaned c2 would make the ALTER fail.
		require.NoError(t, db.Exec("ALTER TABLE test_children DROP CONSTRAINT fk_test_children_parent").Error)
		require.NoError(t, db.Exec("DELETE FROM test_parents WHERE id = ?", "p2").Error)
		require.NoError(t, db.Exec(
			"ALTER TABLE test_children ADD CONSTRAINT fk_test_children_parent "+
				"FOREIGN KEY (parent_id) REFERENCES test_parents(id) NOT VALID",
		).Error)
	case "mysql":
		// MySQL: disable FK checks for the session instead, so the orphan can
		// be created and the constraint re-added without validation.
		require.NoError(t, db.Exec("SET FOREIGN_KEY_CHECKS = 0").Error)
		require.NoError(t, db.Exec("ALTER TABLE test_children DROP FOREIGN KEY fk_test_children_parent").Error)
		require.NoError(t, db.Exec("DELETE FROM test_parents WHERE id = ?", "p2").Error)
		require.NoError(t, db.Exec(
			"ALTER TABLE test_children ADD CONSTRAINT fk_test_children_parent "+
				"FOREIGN KEY (parent_id) REFERENCES test_parents(id)",
		).Error)
		require.NoError(t, db.Exec("SET FOREIGN_KEY_CHECKS = 1").Error)
	}

	// With the FK present, the migration must take the early-exit path and
	// leave the table untouched.
	err = migration.CleanupOrphanedResources[testChildWithFK, testParent](context.Background(), db, "parent_id")
	require.NoError(t, err)

	var count int64
	db.Model(&testChildWithFK{}).Count(&count)
	assert.Equal(t, int64(2), count, "Both rows should survive — migration must skip when FK constraint exists")
}

View File

@@ -396,6 +396,11 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.Services, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
if result.Error != nil {
return result.Error
@@ -2099,7 +2104,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
var mode, source, sourcePeer sql.NullString
var terminated sql.NullBool
var terminated, portAutoAssigned sql.NullBool
var listenPort sql.NullInt64
err := row.Scan(
&s.ID,
&s.AccountID,
@@ -2116,8 +2122,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
&sessionPrivateKey,
&sessionPublicKey,
&mode,
&s.ListenPort,
&s.PortAutoAssigned,
&listenPort,
&portAutoAssigned,
&source,
&sourcePeer,
&terminated,
@@ -2164,6 +2170,12 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
if terminated.Valid {
s.Terminated = terminated.Bool
}
if portAutoAssigned.Valid {
s.PortAutoAssigned = portAutoAssigned.Bool
}
if listenPort.Valid {
s.ListenPort = uint16(listenPort.Int64)
}
s.Targets = []*rpservice.Target{}
return &s, nil
})

View File

@@ -22,6 +22,8 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -350,6 +352,35 @@ func TestSqlite_DeleteAccount(t *testing.T) {
},
}
account.Services = []*rpservice.Service{
{
ID: "service_id",
AccountID: account.Id,
Name: "test service",
Domain: "svc.example.com",
Enabled: true,
Targets: []*rpservice.Target{
{
AccountID: account.Id,
ServiceID: "service_id",
Host: "localhost",
Port: 8080,
Protocol: "http",
Enabled: true,
},
},
},
}
account.Domains = []*proxydomain.Domain{
{
ID: "domain_id",
Domain: "custom.example.com",
AccountID: account.Id,
Validated: true,
},
}
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
@@ -411,6 +442,20 @@ func TestSqlite_DeleteAccount(t *testing.T) {
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network resources")
require.Len(t, resources, 0, "expecting no network resources to be found after DeleteAccount")
}
domains, err := store.ListCustomDomains(context.Background(), account.Id)
require.NoError(t, err, "expecting no error after DeleteAccount when searching for custom domains")
require.Len(t, domains, 0, "expecting no custom domains to be found after DeleteAccount")
var services []*rpservice.Service
err = store.(*SqlStore).db.Model(&rpservice.Service{}).Find(&services, "account_id = ?", account.Id).Error
require.NoError(t, err, "expecting no error after DeleteAccount when searching for services")
require.Len(t, services, 0, "expecting no services to be found after DeleteAccount")
var targets []*rpservice.Target
err = store.(*SqlStore).db.Model(&rpservice.Target{}).Find(&targets, "account_id = ?", account.Id).Error
require.NoError(t, err, "expecting no error after DeleteAccount when searching for service targets")
require.Len(t, targets, 0, "expecting no service targets to be found after DeleteAccount")
}
func Test_GetAccount(t *testing.T) {

View File

@@ -20,6 +20,7 @@ import (
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@@ -265,6 +266,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
&types.AccountOnboarding{}, &service.Service{}, &service.Target{},
&domain.Domain{},
}
for i := len(models) - 1; i >= 0; i-- {

View File

@@ -448,6 +448,12 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.RemoveDuplicatePeerKeys(ctx, db)
},
func(db *gorm.DB) error {
return migration.CleanupOrphanedResources[rpservice.Service, types.Account](ctx, db, "account_id")
},
func(db *gorm.DB) error {
return migration.CleanupOrphanedResources[domain.Domain, types.Account](ctx, db, "account_id")
},
}
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
@@ -101,6 +102,7 @@ type Account struct {
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
Services []*service.Service `gorm:"foreignKey:AccountID;references:id"`
Domains []*proxydomain.Domain `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
@@ -911,6 +913,11 @@ func (a *Account) Copy() *Account {
services = append(services, svc.Copy())
}
domains := []*proxydomain.Domain{}
for _, domain := range a.Domains {
domains = append(domains, domain.Copy())
}
return &Account{
Id: a.Id,
CreatedBy: a.CreatedBy,
@@ -936,6 +943,7 @@ func (a *Account) Copy() *Account {
Onboarding: a.Onboarding,
NetworkMapCache: a.NetworkMapCache,
nmapInitOnce: a.nmapInitOnce,
Domains: domains,
}
}

View File

@@ -333,7 +333,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
dialers := c.getDialers()
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
conn, err := rd.Dial()
conn, err := rd.Dial(ctx)
if err != nil {
return nil, err
}

View File

@@ -40,10 +40,10 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri
}
}
func (r *RaceDial) Dial() (net.Conn, error) {
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
connChan := make(chan dialResult, len(r.dialerFns))
winnerConn := make(chan net.Conn, 1)
abortCtx, abort := context.WithCancel(context.Background())
abortCtx, abort := context.WithCancel(ctx)
defer abort()
for _, dfn := range r.dialerFns {

View File

@@ -78,7 +78,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
serverURL := "test.server.com"
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
conn, err := rd.Dial()
conn, err := rd.Dial(context.Background())
if err == nil {
t.Errorf("Expected an error with empty dialers, got nil")
}
@@ -104,7 +104,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
conn, err := rd.Dial()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
@@ -137,7 +137,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
@@ -160,7 +160,7 @@ func TestRaceDialTimeout(t *testing.T) {
}
rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
conn, err := rd.Dial()
conn, err := rd.Dial(context.Background())
if err == nil {
t.Errorf("Expected an error, got nil")
}
@@ -188,7 +188,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
conn, err := rd.Dial(context.Background())
if err == nil {
t.Errorf("Expected an error, got nil")
}
@@ -230,7 +230,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
conn, err := rd.Dial(context.Background())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}