Allow signing downloaded RDP file (#156)

Implement signing of RDP files downloaded from web
This commit is contained in:
Andrew Heberle
2025-09-05 20:21:32 +08:00
committed by GitHub
parent 10722d7105
commit 2b9ec4a3f0
10 changed files with 259 additions and 34 deletions

View File

@@ -328,6 +328,12 @@ Client:
SplitUserDomain: false SplitUserDomain: false
# If true, removes "username" (and "domain" if SplitUserDomain is true) from RDP file. # If true, removes "username" (and "domain" if SplitUserDomain is true) from RDP file.
# NoUsername: true # NoUsername: true
# If both SigningCert and SigningKey are set the downloaded RDP file will be signed
# so the client can authenticate the validity of the RDP file and reduce warnings from
# the client if the CA that issued the certificate is trusted. Both should be PEM encoded
# and the key must be an unencrypted RSA private key.
# SigningCert: /path/to/signing.crt
# SigningKey: /path/to/signing.key
Security: Security:
# a random string of 32 characters to secure cookies on the client # a random string of 32 characters to secure cookies on the client
# make sure to share this amongst different pods # make sure to share this amongst different pods

View File

@@ -1,15 +1,16 @@
package config package config
import ( import (
"log"
"os"
"strings"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/confmap" "github.com/knadh/koanf/providers/confmap"
"github.com/knadh/koanf/providers/env" "github.com/knadh/koanf/providers/env"
"github.com/knadh/koanf/providers/file" "github.com/knadh/koanf/providers/file"
"github.com/knadh/koanf/v2" "github.com/knadh/koanf/v2"
"log"
"os"
"strings"
) )
const ( const (
@@ -96,6 +97,8 @@ type ClientConfig struct {
UsernameTemplate string `koanf:"usernametemplate"` UsernameTemplate string `koanf:"usernametemplate"`
SplitUserDomain bool `koanf:"splituserdomain"` SplitUserDomain bool `koanf:"splituserdomain"`
NoUsername bool `koanf:"nousername"` NoUsername bool `koanf:"nousername"`
SigningCert string `koanf:"signingcert"`
SigningKey string `koanf:"signingkey"`
} }
func ToCamel(s string) string { func ToCamel(s string) string {
@@ -219,10 +222,10 @@ func Load(configFile string) Configuration {
if Conf.Server.BasicAuthEnabled() && Conf.Server.Tls == "disable" { if Conf.Server.BasicAuthEnabled() && Conf.Server.Tls == "disable" {
log.Fatalf("basicauth=local and tls=disable are mutually exclusive") log.Fatalf("basicauth=local and tls=disable are mutually exclusive")
} }
if Conf.Server.NtlmEnabled() && Conf.Server.KerberosEnabled() { if Conf.Server.NtlmEnabled() && Conf.Server.KerberosEnabled() {
log.Fatalf("ntlm and kerberos authentication are not stackable") log.Fatalf("ntlm and kerberos authentication are not stackable")
} }
if !Conf.Caps.TokenAuth && Conf.Server.OpenIDEnabled() { if !Conf.Caps.TokenAuth && Conf.Server.OpenIDEnabled() {
log.Fatalf("openid is configured but tokenauth disabled") log.Fatalf("openid is configured but tokenauth disabled")
@@ -238,7 +241,6 @@ func Load(configFile string) Configuration {
} }
return Conf return Conf
} }
func (s *ServerConfig) OpenIDEnabled() bool { func (s *ServerConfig) OpenIDEnabled() bool {

View File

@@ -4,6 +4,12 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"log"
"net/http"
"net/url"
"os"
"strconv"
"github.com/bolkedebruin/gokrb5/v8/keytab" "github.com/bolkedebruin/gokrb5/v8/keytab"
"github.com/bolkedebruin/gokrb5/v8/service" "github.com/bolkedebruin/gokrb5/v8/service"
"github.com/bolkedebruin/gokrb5/v8/spnego" "github.com/bolkedebruin/gokrb5/v8/spnego"
@@ -18,11 +24,6 @@ import (
"github.com/thought-machine/go-flags" "github.com/thought-machine/go-flags"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"log"
"net/http"
"net/url"
"os"
"strconv"
) )
const ( const (
@@ -110,10 +111,12 @@ func main() {
RdpOpts: web.RdpOpts{ RdpOpts: web.RdpOpts{
UsernameTemplate: conf.Client.UsernameTemplate, UsernameTemplate: conf.Client.UsernameTemplate,
SplitUserDomain: conf.Client.SplitUserDomain, SplitUserDomain: conf.Client.SplitUserDomain,
NoUsername: conf.Client.NoUsername, NoUsername: conf.Client.NoUsername,
}, },
GatewayAddress: url, GatewayAddress: url,
TemplateFile: conf.Client.Defaults, TemplateFile: conf.Client.Defaults,
RdpSigningCert: conf.Client.SigningCert,
RdpSigningKey: conf.Client.SigningKey,
} }
if conf.Caps.TokenAuth { if conf.Caps.TokenAuth {
@@ -229,7 +232,7 @@ func main() {
// for stacking of authentication // for stacking of authentication
auth := web.NewAuthMux() auth := web.NewAuthMux()
rdp.MatcherFunc(web.NoAuthz).HandlerFunc(auth.SetAuthenticate) rdp.MatcherFunc(web.NoAuthz).HandlerFunc(auth.SetAuthenticate)
// ntlm // ntlm
if conf.Server.NtlmEnabled() { if conf.Server.NtlmEnabled() {
log.Printf("enabling NTLM authentication") log.Printf("enabling NTLM authentication")
@@ -238,7 +241,7 @@ func main() {
rdp.NewRoute().HeadersRegexp("Authorization", "Negotiate").HandlerFunc(ntlm.NTLMAuth(gw.HandleGatewayProtocol)) rdp.NewRoute().HeadersRegexp("Authorization", "Negotiate").HandlerFunc(ntlm.NTLMAuth(gw.HandleGatewayProtocol))
auth.Register(`NTLM`) auth.Register(`NTLM`)
auth.Register(`Negotiate`) auth.Register(`Negotiate`)
} }
// basic auth // basic auth
if conf.Server.BasicAuthEnabled() { if conf.Server.BasicAuthEnabled() {

View File

@@ -3,17 +3,18 @@ package protocol
import ( import (
"context" "context"
"errors" "errors"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
"log" "log"
"net" "net"
"net/http" "net/http"
"reflect" "reflect"
"syscall" "syscall"
"time" "time"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
) )
const ( const (
@@ -140,7 +141,7 @@ func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
if !ptrSysFd.IsValid() { if !ptrSysFd.IsValid() {
return errors.New("cannot find Sysfd field") return errors.New("cannot find Sysfd field")
} }
fd := int(ptrSysFd.Int()) fd := int64ToFd(ptrSysFd.Int())
if g.ReceiveBuf > 0 { if g.ReceiveBuf > 0 {
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ReceiveBuf) err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ReceiveBuf)

View File

@@ -0,0 +1,8 @@
//go:build !windows
package protocol
// the fd arg to syscall.SetsockoptInt on Linix is of type int
func int64ToFd(n int64) int {
return int(n)
}

View File

@@ -0,0 +1,10 @@
package protocol
import (
"syscall"
)
// the fd arg to syscall.SetsockoptInt on Windows is of type syscall.Handle
func int64ToFd(n int64) syscall.Handle {
return syscall.Handle(n)
}

View File

@@ -4,12 +4,13 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"net/http"
"time"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"net/http"
"time"
) )
const ( const (
@@ -91,7 +92,7 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
id.SetAuthTime(time.Now()) id.SetAuthTime(time.Now())
id.SetAttribute(identity.AttrAccessToken, oauth2Token.AccessToken) id.SetAttribute(identity.AttrAccessToken, oauth2Token.AccessToken)
if err = SaveSessionIdentity(r, w, id); err != nil { if err := SaveSessionIdentity(r, w, id); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} }

View File

@@ -1,13 +1,12 @@
package web package web
import ( import (
"bytes"
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp"
"hash/maphash" "hash/maphash"
"log" "log"
rnd "math/rand" rnd "math/rand"
@@ -15,6 +14,10 @@ import (
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/andrewheberle/rdpsign"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp"
) )
type TokenGeneratorFunc func(context.Context, string, string) (string, error) type TokenGeneratorFunc func(context.Context, string, string) (string, error)
@@ -32,6 +35,8 @@ type Config struct {
GatewayAddress *url.URL GatewayAddress *url.URL
RdpOpts RdpOpts RdpOpts RdpOpts
TemplateFile string TemplateFile string
RdpSigningCert string
RdpSigningKey string
} }
type RdpOpts struct { type RdpOpts struct {
@@ -51,6 +56,7 @@ type Handler struct {
hostSelection string hostSelection string
rdpOpts RdpOpts rdpOpts RdpOpts
rdpDefaults string rdpDefaults string
rdpSigner *rdpsign.Signer
} }
func (c *Config) NewHandler() *Handler { func (c *Config) NewHandler() *Handler {
@@ -58,7 +64,7 @@ func (c *Config) NewHandler() *Handler {
log.Fatal("Not enough hosts to connect to specified") log.Fatal("Not enough hosts to connect to specified")
} }
return &Handler{ handler := &Handler{
paaTokenGenerator: c.PAATokenGenerator, paaTokenGenerator: c.PAATokenGenerator,
enableUserToken: c.EnableUserToken, enableUserToken: c.EnableUserToken,
userTokenGenerator: c.UserTokenGenerator, userTokenGenerator: c.UserTokenGenerator,
@@ -70,6 +76,18 @@ func (c *Config) NewHandler() *Handler {
rdpOpts: c.RdpOpts, rdpOpts: c.RdpOpts,
rdpDefaults: c.TemplateFile, rdpDefaults: c.TemplateFile,
} }
// set up RDP signer if config values are set
if c.RdpSigningCert != "" && c.RdpSigningKey != "" {
signer, err := rdpsign.New(c.RdpSigningCert, c.RdpSigningKey)
if err != nil {
log.Fatal("Could not set up RDP signer", err)
}
handler.rdpSigner = signer
}
return handler
} }
func (h *Handler) selectRandomHost() string { func (h *Handler) selectRandomHost() string {
@@ -160,7 +178,7 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
render := user render := user
if opts.UsernameTemplate != "" { if opts.UsernameTemplate != "" {
render = fmt.Sprintf(h.rdpOpts.UsernameTemplate) render = fmt.Sprint(h.rdpOpts.UsernameTemplate)
render = strings.Replace(render, "{{ username }}", user, 1) render = strings.Replace(render, "{{ username }}", user, 1)
if h.rdpOpts.UsernameTemplate == render { if h.rdpOpts.UsernameTemplate == render {
log.Printf("Invalid username template. %s == %s", h.rdpOpts.UsernameTemplate, user) log.Printf("Invalid username template. %s == %s", h.rdpOpts.UsernameTemplate, user)
@@ -224,5 +242,23 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
d.Settings.GatewayCredentialMethod = 1 d.Settings.GatewayCredentialMethod = 1
d.Settings.GatewayUsageMethod = 1 d.Settings.GatewayUsageMethod = 1
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(d.String())) // no rdp siging so return as-is
if h.rdpSigner == nil {
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(d.String()))
return
}
// get rdp content
rdpContent := d.String()
// sign rdp content
signedContent, err := h.rdpSigner.Sign(rdpContent)
if err != nil {
log.Printf("Could not sign RDP file due to %s", err)
http.Error(w, errors.New("could not sign RDP file").Error(), http.StatusInternalServerError)
return
}
// return signd rdp file
http.ServeContent(w, r, fn, time.Now(), bytes.NewReader(signedContent))
} }

View File

@@ -2,15 +2,25 @@ package web
import ( import (
"context" "context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "crypto/rand"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp" "crypto/rsa"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" "crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"os" "os"
"strings" "strings"
"testing" "testing"
"time"
"github.com/andrewheberle/rdpsign"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"github.com/spf13/afero"
) )
const ( const (
@@ -172,6 +182,89 @@ func TestHandler_HandleDownload(t *testing.T) {
} }
func TestHandler_HandleSignedDownload(t *testing.T) {
req, err := http.NewRequest("GET", "/connect", nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
id := identity.NewUser()
id.SetUserName(testuser)
id.SetAuthenticated(true)
req = identity.AddToRequestCtx(id, req)
ctx := req.Context()
u, _ := url.Parse(gateway)
c := Config{
HostSelection: "roundrobin",
Hosts: hosts,
PAATokenGenerator: paaTokenMock,
GatewayAddress: u,
RdpOpts: RdpOpts{SplitUserDomain: true},
}
h := c.NewHandler()
// set up rdp signer
fs := afero.NewMemMapFs()
if err := genKeypair(fs); err != nil {
t.Errorf("could not generate key pair for testing: %s", err)
}
signer, err := rdpsign.New("test.crt", "test.key", rdpsign.WithFs(fs))
if err != nil {
t.Errorf("could not create *rdpsign.Signer for testing: %s", err)
}
h.rdpSigner = signer
hh := http.HandlerFunc(h.HandleDownload)
hh.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
if ctype := rr.Header().Get("Content-Type"); ctype != "application/x-rdp" {
t.Errorf("content type header does not match: got %v want %v",
ctype, "application/json")
}
if cdisp := rr.Header().Get("Content-Disposition"); cdisp == "" {
t.Errorf("content disposition is nil")
}
data := rdpToMap(strings.Split(rr.Body.String(), rdp.CRLF))
if data["username"] != testuser {
t.Errorf("username key in rdp does not match: got %v want %v", data["username"], testuser)
}
if data["gatewayhostname"] != u.Host {
t.Errorf("gatewayhostname key in rdp does not match: got %v want %v", data["gatewayhostname"], u.Host)
}
if token, _ := paaTokenMock(ctx, testuser, data["full address"]); token != data["gatewayaccesstoken"] {
t.Errorf("gatewayaccesstoken key in rdp does not match username_full address: got %v want %v",
data["gatewayaccesstoken"], token)
}
if !contains(data["full address"], hosts) {
t.Errorf("full address key in rdp is not in allowed hosts list: go %v want in %v",
data["full address"], hosts)
}
signscopeWant := "GatewayHostname,Full Address,GatewayCredentialsSource,GatewayProfileUsageMethod,GatewayUsageMethod,Alternate Full Address"
if data["signscope"] != signscopeWant {
t.Errorf("signscope key in rdp does not match: got %v want %v", data["signscope"], signscopeWant)
}
if _, found := data["signature"]; !found {
t.Errorf("no signature found in rdp")
}
}
func TestHandler_HandleDownloadWithRdpTemplate(t *testing.T) { func TestHandler_HandleDownloadWithRdpTemplate(t *testing.T) {
f, err := os.CreateTemp("", "rdp") f, err := os.CreateTemp("", "rdp")
if err != nil { if err != nil {
@@ -233,3 +326,68 @@ func rdpToMap(rdp []string) map[string]string {
return ret return ret
} }
func genKeypair(fs afero.Fs) error {
// generate private key
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
}
// convert to DER
der, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return err
}
// encode DER private key as PEM
if err := func() error {
f, err := fs.Create("test.key")
if err != nil {
return err
}
defer f.Close()
return pem.Encode(f, &pem.Block{
Type: "PRIVATE KEY",
Bytes: der,
})
}(); err != nil {
return err
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Example Organization"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute * 10),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return err
}
// encode cert as PEM
if err := func() error {
f, err := fs.Create("test.crt")
if err != nil {
return err
}
defer f.Close()
return pem.Encode(f, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
}(); err != nil {
return err
}
return nil
}

6
go.mod
View File

@@ -1,10 +1,9 @@
module github.com/bolkedebruin/rdpgw module github.com/bolkedebruin/rdpgw
go 1.23.0 go 1.24.2
toolchain go1.24.1
require ( require (
github.com/andrewheberle/rdpsign v1.1.0
github.com/bolkedebruin/gokrb5/v8 v8.5.0 github.com/bolkedebruin/gokrb5/v8 v8.5.0
github.com/coreos/go-oidc/v3 v3.9.0 github.com/coreos/go-oidc/v3 v3.9.0
github.com/fatih/structs v1.1.0 github.com/fatih/structs v1.1.0
@@ -25,6 +24,7 @@ require (
github.com/msteinert/pam/v2 v2.0.0 github.com/msteinert/pam/v2 v2.0.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/prometheus/client_golang v1.19.0 github.com/prometheus/client_golang v1.19.0
github.com/spf13/afero v1.14.0
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
github.com/thought-machine/go-flags v1.6.3 github.com/thought-machine/go-flags v1.6.3
golang.org/x/crypto v0.36.0 golang.org/x/crypto v0.36.0