Add Makefile build

This prepares for multiple binaries and distribution builds
This commit is contained in:
Bolke de Bruin
2021-05-17 09:53:30 +02:00
parent bb2501c7a6
commit bf362b4e52
17 changed files with 90 additions and 22 deletions

40
cmd/rdpgw/api/token.go Normal file
View File

@@ -0,0 +1,40 @@
package api
import (
"context"
"encoding/json"
"fmt"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"log"
"net/http"
)
func (c *Config) TokenInfo(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Invalid request", http.StatusMethodNotAllowed)
return
}
tokens, ok := r.URL.Query()["access_token"]
if !ok || len(tokens[0]) < 1 {
log.Printf("Missing access_token in request")
http.Error(w, "access_token missing in request", http.StatusBadRequest)
return
}
token := tokens[0]
info, err := security.UserInfo(context.Background(), token)
if err != nil {
log.Printf("Token validation failed due to %s", err)
http.Error(w, fmt.Sprintf("token validation failed due to %s", err), http.StatusForbidden)
return
}
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
if err = json.NewEncoder(w).Encode(info); err != nil {
log.Printf("Cannot encode json due to %s", err)
http.Error(w, "cannot encode json", http.StatusInternalServerError)
return
}
}

220
cmd/rdpgw/api/web.go Normal file
View File

@@ -0,0 +1,220 @@
package api
import (
"context"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/sessions"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"log"
"math/rand"
"net/http"
"strconv"
"strings"
"time"
)
const (
RdpGwSession = "RDPGWSESSION"
MaxAge = 120
)
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
type Config struct {
SessionKey []byte
SessionEncryptionKey []byte
PAATokenGenerator TokenGeneratorFunc
UserTokenGenerator UserTokenGeneratorFunc
EnableUserToken bool
OAuth2Config *oauth2.Config
store *sessions.CookieStore
OIDCTokenVerifier *oidc.IDTokenVerifier
stateStore *cache.Cache
Hosts []string
GatewayAddress string
UsernameTemplate string
NetworkAutoDetect int
BandwidthAutoDetect int
ConnectionType int
SplitUserDomain bool
DefaultDomain string
}
func (c *Config) NewApi() {
if len(c.SessionKey) < 32 {
log.Fatal("Session key too small")
}
if len(c.Hosts) < 1 {
log.Fatal("Not enough hosts to connect to specified")
}
c.store = sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey)
c.stateStore = cache.New(time.Minute*2, 5*time.Minute)
}
func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) {
state := r.URL.Query().Get("state")
s, found := c.stateStore.Get(state)
if !found {
http.Error(w, "unknown state", http.StatusBadRequest)
return
}
url := s.(string)
ctx := context.Background()
oauth2Token, err := c.OAuth2Config.Exchange(ctx, r.URL.Query().Get("code"))
if err != nil {
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
return
}
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
return
}
idToken, err := c.OIDCTokenVerifier.Verify(ctx, rawIDToken)
if err != nil {
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
return
}
resp := struct {
OAuth2Token *oauth2.Token
IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
}{oauth2Token, new(json.RawMessage)}
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var data map[string]interface{}
if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
session, err := c.store.Get(r, RdpGwSession)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
session.Options.MaxAge = MaxAge
session.Values["preferred_username"] = data["preferred_username"]
session.Values["authenticated"] = true
session.Values["access_token"] = oauth2Token.AccessToken
if err = session.Save(r, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
http.Redirect(w, r, url, http.StatusFound)
}
func (c *Config) Authenticated(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, err := c.store.Get(r, RdpGwSession)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
found := session.Values["authenticated"]
if found == nil || !found.(bool) {
seed := make([]byte, 16)
rand.Read(seed)
state := hex.EncodeToString(seed)
c.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration)
http.Redirect(w, r, c.OAuth2Config.AuthCodeURL(state), http.StatusFound)
return
}
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userName, ok := ctx.Value("preferred_username").(string)
if !ok {
log.Printf("preferred_username not found in context")
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
return
}
// do a round robin selection for now
rand.Seed(time.Now().Unix())
host := c.Hosts[rand.Intn(len(c.Hosts))]
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
// split the username into user and domain
var user = userName
var domain = c.DefaultDomain
if c.SplitUserDomain {
creds := strings.SplitN(userName, "@", 2)
user = creds[0]
if len(creds) > 1 {
domain = creds[1]
}
}
render := user
if c.UsernameTemplate != "" {
render = fmt.Sprintf(c.UsernameTemplate)
render = strings.Replace(render, "{{ username }}", user, 1)
if c.UsernameTemplate == render {
log.Printf("Invalid username template. %s == %s", c.UsernameTemplate, user)
http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError)
return
}
}
token, err := c.PAATokenGenerator(ctx, user, host)
if err != nil {
log.Printf("Cannot generate PAA token for user %s due to %s", user, err)
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
}
if c.EnableUserToken {
userToken, err := c.UserTokenGenerator(ctx, user)
if err != nil {
log.Printf("Cannot generate token for user %s due to %s", user, err)
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
}
render = strings.Replace(render, "{{ token }}", userToken, 1)
}
// authenticated
seed := make([]byte, 16)
rand.Read(seed)
fn := hex.EncodeToString(seed) + ".rdp"
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
w.Header().Set("Content-Type", "application/x-rdp")
data := "full address:s:"+host+"\r\n"+
"gatewayhostname:s:"+c.GatewayAddress+"\r\n"+
"gatewaycredentialssource:i:5\r\n"+
"gatewayusagemethod:i:1\r\n"+
"gatewayprofileusagemethod:i:1\r\n"+
"gatewayaccesstoken:s:"+token+"\r\n"+
"networkautodetect:i:"+strconv.Itoa(c.NetworkAutoDetect)+"\r\n"+
"bandwidthautodetect:i:"+strconv.Itoa(c.BandwidthAutoDetect)+"\r\n"+
"connection type:i:"+strconv.Itoa(c.ConnectionType)+"\r\n"+
"username:s:"+render+"\r\n"+
"domain:s:"+domain+"\r\n"+
"bitmapcachesize:i:32000\r\n"+
"smart sizing:i:1\r\n"
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data))
}

View File

@@ -0,0 +1,60 @@
package common
import (
"context"
"log"
"net"
"net/http"
"strings"
)
const (
ClientIPCtx = "ClientIP"
ProxyAddressesCtx = "ProxyAddresses"
RemoteAddressCtx = "RemoteAddress"
)
func EnrichContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
h := r.Header.Get("X-Forwarded-For")
if h != "" {
var proxies []string
ips := strings.Split(h, ",")
for i := range ips {
ips[i] = strings.TrimSpace(ips[i])
}
clientIp := ips[0]
if len(ips) > 1 {
proxies = ips[1:]
}
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
ctx = context.WithValue(ctx, ProxyAddressesCtx, proxies)
}
ctx = context.WithValue(ctx, RemoteAddressCtx, r.RemoteAddr)
if h == "" {
clientIp, _, _ := net.SplitHostPort(r.RemoteAddr)
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func GetClientIp(ctx context.Context) string {
s, ok := ctx.Value(ClientIPCtx).(string)
if !ok {
return ""
}
return s
}
func GetAccessToken(ctx context.Context) string {
token, ok := ctx.Value("access_token").(string)
if !ok {
log.Printf("cannot get access token from context")
return ""
}
return token
}

View File

@@ -0,0 +1,97 @@
package config
import (
"github.com/spf13/viper"
"log"
)
type Configuration struct {
Server ServerConfig
OpenId OpenIDConfig
Caps RDGCapsConfig
Security SecurityConfig
Client ClientConfig
}
type ServerConfig struct {
GatewayAddress string
Port int
CertFile string
KeyFile string
Hosts []string
RoundRobin bool
SessionKey string
SessionEncryptionKey string
SendBuf int
ReceiveBuf int
}
type OpenIDConfig struct {
ProviderUrl string
ClientId string
ClientSecret string
}
type RDGCapsConfig struct {
SmartCardAuth bool
TokenAuth bool
IdleTimeout int
RedirectAll bool
DisableRedirect bool
EnableClipboard bool
EnablePrinter bool
EnablePort bool
EnablePnp bool
EnableDrive bool
}
type SecurityConfig struct {
PAATokenEncryptionKey string
PAATokenSigningKey string
UserTokenEncryptionKey string
UserTokenSigningKey string
VerifyClientIp bool
EnableUserToken bool
}
type ClientConfig struct {
NetworkAutoDetect int
BandwidthAutoDetect int
ConnectionType int
UsernameTemplate string
SplitUserDomain bool
DefaultDomain string
}
func init() {
viper.SetDefault("server.certFile", "server.pem")
viper.SetDefault("server.keyFile", "key.pem")
viper.SetDefault("server.port", 443)
viper.SetDefault("client.networkAutoDetect", 1)
viper.SetDefault("client.bandwidthAutoDetect", 1)
viper.SetDefault("security.verifyClientIp", true)
}
func Load(configFile string) Configuration {
var conf Configuration
viper.SetConfigName("rdpgw")
viper.SetConfigFile(configFile)
viper.AddConfigPath(".")
viper.SetEnvPrefix("RDPGW")
viper.AutomaticEnv()
if err := viper.ReadInConfig(); err != nil {
log.Fatalf("No config file found (%s)", err)
}
if err := viper.Unmarshal(&conf); err != nil {
log.Fatalf("Cannot unmarshal the config file; %s", err)
}
if len(conf.Security.PAATokenSigningKey) < 32 {
log.Fatalf("Token signing key not long enough")
}
return conf
}

148
cmd/rdpgw/main.go Normal file
View File

@@ -0,0 +1,148 @@
package main
import (
"context"
"crypto/tls"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/api"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/spf13/cobra"
"golang.org/x/oauth2"
"log"
"net/http"
"os"
"strconv"
)
var cmd = &cobra.Command{
Use: "rdpgw",
Long: "Remote Desktop Gateway",
}
var (
configFile string
)
var conf config.Configuration
func main() {
// get config
cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)")
conf = config.Load(configFile)
security.VerifyClientIP = conf.Security.VerifyClientIp
// set security keys
security.SigningKey = []byte(conf.Security.PAATokenSigningKey)
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey)
// set oidc config
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
if err != nil {
log.Fatalf("Cannot get oidc provider: %s", err)
}
oidcConfig := &oidc.Config{
ClientID: conf.OpenId.ClientId,
}
verifier := provider.Verifier(oidcConfig)
oauthConfig := oauth2.Config{
ClientID: conf.OpenId.ClientId,
ClientSecret: conf.OpenId.ClientSecret,
RedirectURL: "https://" + conf.Server.GatewayAddress + "/callback",
Endpoint: provider.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
security.OIDCProvider = provider
security.Oauth2Config = oauthConfig
api := &api.Config{
GatewayAddress: conf.Server.GatewayAddress,
OAuth2Config: &oauthConfig,
OIDCTokenVerifier: verifier,
PAATokenGenerator: security.GeneratePAAToken,
UserTokenGenerator: security.GenerateUserToken,
EnableUserToken: conf.Security.EnableUserToken,
SessionKey: []byte(conf.Server.SessionKey),
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
Hosts: conf.Server.Hosts,
NetworkAutoDetect: conf.Client.NetworkAutoDetect,
UsernameTemplate: conf.Client.UsernameTemplate,
BandwidthAutoDetect: conf.Client.BandwidthAutoDetect,
ConnectionType: conf.Client.ConnectionType,
SplitUserDomain: conf.Client.SplitUserDomain,
DefaultDomain: conf.Client.DefaultDomain,
}
api.NewApi()
if conf.Server.CertFile == "" || conf.Server.KeyFile == "" {
log.Fatal("Both certfile and keyfile need to be specified")
}
//mux := http.NewServeMux()
//mux.HandleFunc("*", HelloServer)
log.Printf("Starting remote desktop gateway server")
cfg := &tls.Config{}
tlsDebug := os.Getenv("SSLKEYLOGFILE")
if tlsDebug != "" {
w, err := os.OpenFile(tlsDebug, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
log.Fatalf("Cannot open key log file %s for writing %s", tlsDebug, err)
}
log.Printf("Key log file set to: %s", tlsDebug)
cfg.KeyLogWriter = w
}
cert, err := tls.LoadX509KeyPair(conf.Server.CertFile, conf.Server.KeyFile)
if err != nil {
log.Fatal(err)
}
cfg.Certificates = append(cfg.Certificates, cert)
server := http.Server{
Addr: ":" + strconv.Itoa(conf.Server.Port),
TLSConfig: cfg,
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
}
// create the gateway
handlerConfig := protocol.ServerConf{
IdleTimeout: conf.Caps.IdleTimeout,
TokenAuth: conf.Caps.TokenAuth,
SmartCardAuth: conf.Caps.SmartCardAuth,
RedirectFlags: protocol.RedirectFlags{
Clipboard: conf.Caps.EnableClipboard,
Drive: conf.Caps.EnableDrive,
Printer: conf.Caps.EnablePrinter,
Port: conf.Caps.EnablePort,
Pnp: conf.Caps.EnablePnp,
DisableAll: conf.Caps.DisableRedirect,
EnableAll: conf.Caps.RedirectAll,
},
VerifyTunnelCreate: security.VerifyPAAToken,
VerifyServerFunc: security.VerifyServerFunc,
SendBuf: conf.Server.SendBuf,
ReceiveBuf: conf.Server.ReceiveBuf,
}
gw := protocol.Gateway{
ServerConf: &handlerConfig,
}
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
http.Handle("/metrics", promhttp.Handler())
http.HandleFunc("/tokeninfo", api.TokenInfo)
http.HandleFunc("/callback", api.HandleCallback)
err = server.ListenAndServeTLS("", "")
if err != nil {
log.Fatal("ListenAndServe: ", err)
}
}

View File

@@ -0,0 +1,246 @@
package protocol
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"log"
"net"
)
const (
MajorVersion = 0x0
MinorVersion = 0x0
Version = 0x00
)
type ClientConfig struct {
SmartCardAuth bool
PAAToken string
NTLMAuth bool
Session *SessionInfo
LocalConn net.Conn
Server string
Port int
Name string
}
func (c *ClientConfig) ConnectAndForward() error {
c.Session.TransportOut.WritePacket(c.handshakeRequest())
for {
pt, sz, pkt, err := readMessage(c.Session.TransportIn)
if err != nil {
log.Printf("Cannot read message from stream %s", err)
return err
}
switch pt {
case PKT_TYPE_HANDSHAKE_RESPONSE:
caps, err := c.handshakeResponse(pkt)
if err != nil {
log.Printf("Cannot connect to %s due to %s", c.Server, err)
return err
}
log.Printf("Handshake response received. Caps: %d", caps)
c.Session.TransportOut.WritePacket(c.tunnelRequest())
case PKT_TYPE_TUNNEL_RESPONSE:
tid, caps, err := c.tunnelResponse(pkt)
if err != nil {
log.Printf("Cannot setup tunnel due to %s", err)
return err
}
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
c.Session.TransportOut.WritePacket(c.tunnelAuthRequest())
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
flags, timeout, err := c.tunnelAuthResponse(pkt)
if err != nil {
log.Printf("Cannot do tunnel auth due to %s", err)
return err
}
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
c.Session.TransportOut.WritePacket(c.channelRequest())
case PKT_TYPE_CHANNEL_RESPONSE:
cid, err := c.channelResponse(pkt)
if err != nil {
log.Printf("Cannot do tunnel auth due to %s", err)
return err
}
if cid < 1 {
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
}
log.Printf("Channel creation succesful. Channel id: %d", cid)
go forward(c.LocalConn, c.Session.TransportOut)
case PKT_TYPE_DATA:
receive(pkt, c.LocalConn)
default:
log.Printf("Unknown packet type received: %d size %d", pt, sz)
}
}
}
func (c *ClientConfig) handshakeRequest() []byte {
var caps uint16
if c.SmartCardAuth {
caps = caps | HTTP_EXTENDED_AUTH_SC
}
if len(c.PAAToken) > 0 {
caps = caps | HTTP_EXTENDED_AUTH_PAA
}
if c.NTLMAuth {
caps = caps | HTTP_EXTENDED_AUTH_SSPI_NTLM
}
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, byte(MajorVersion))
binary.Write(buf, binary.LittleEndian, byte(MinorVersion))
binary.Write(buf, binary.LittleEndian, uint16(Version))
binary.Write(buf, binary.LittleEndian, uint16(caps))
return createPacket(PKT_TYPE_HANDSHAKE_REQUEST, buf.Bytes())
}
func (c *ClientConfig) handshakeResponse(data []byte) (caps uint16, err error) {
var errorCode int32
var major byte
var minor byte
var version uint16
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &errorCode)
binary.Read(r, binary.LittleEndian, &major)
binary.Read(r, binary.LittleEndian, &minor)
binary.Read(r, binary.LittleEndian, &version)
binary.Read(r, binary.LittleEndian, &caps)
if errorCode > 0 {
return 0, fmt.Errorf("error code: %d", errorCode)
}
return caps, nil
}
func (c *ClientConfig) tunnelRequest() []byte {
buf := new(bytes.Buffer)
var caps uint32
var size uint16
var fields uint16
if len(c.PAAToken) > 0 {
fields = fields | HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE
}
caps = caps | HTTP_CAPABILITY_IDLE_TIMEOUT
binary.Write(buf, binary.LittleEndian, caps)
binary.Write(buf, binary.LittleEndian, fields)
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
if len(c.PAAToken) > 0 {
utf16Token := EncodeUTF16(c.PAAToken)
size = uint16(len(utf16Token))
binary.Write(buf, binary.LittleEndian, size)
buf.Write(utf16Token)
}
return createPacket(PKT_TYPE_TUNNEL_CREATE, buf.Bytes())
}
func (c *ClientConfig) tunnelResponse(data []byte) (tunnelId uint32, caps uint32, err error) {
var version uint16
var errorCode uint32
var fields uint16
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &version)
binary.Read(r, binary.LittleEndian, &errorCode)
binary.Read(r, binary.LittleEndian, &fields)
r.Seek(2, io.SeekCurrent)
if (fields & HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID) == HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID {
binary.Read(r, binary.LittleEndian, &tunnelId)
}
if (fields & HTTP_TUNNEL_RESPONSE_FIELD_CAPS) == HTTP_TUNNEL_RESPONSE_FIELD_CAPS {
binary.Read(r, binary.LittleEndian, &caps)
}
if errorCode != 0 {
err = fmt.Errorf("tunnel error %d", errorCode)
}
return
}
func (c *ClientConfig) tunnelAuthRequest() []byte {
utf16name := EncodeUTF16(c.Name)
size := uint16(len(utf16name))
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, size)
buf.Write(utf16name)
return createPacket(PKT_TYPE_TUNNEL_AUTH, buf.Bytes())
}
func (c *ClientConfig) tunnelAuthResponse(data []byte) (flags uint32, timeout uint32, err error) {
var errorCode uint32
var fields uint16
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &errorCode)
binary.Read(r, binary.LittleEndian, &fields)
r.Seek(2, io.SeekCurrent)
if (fields & HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS) == HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS {
binary.Read(r, binary.LittleEndian, &flags)
}
if (fields & HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT) == HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT {
binary.Read(r, binary.LittleEndian, &timeout)
}
if errorCode > 0 {
return 0, 0, fmt.Errorf("tunnel auth error %d", errorCode)
}
return
}
func (c *ClientConfig) channelRequest() []byte {
utf16server := EncodeUTF16(c.Server)
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names
binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3)
binary.Write(buf, binary.LittleEndian, uint16(c.Port))
binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3
binary.Write(buf, binary.LittleEndian, uint16(len(utf16server)))
buf.Write(utf16server)
return createPacket(PKT_TYPE_CHANNEL_CREATE, buf.Bytes())
}
func (c *ClientConfig) channelResponse(data []byte) (channelId uint32, err error) {
var errorCode uint32
var fields uint16
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &errorCode)
binary.Read(r, binary.LittleEndian, &fields)
r.Seek(2, io.SeekCurrent)
if (fields & HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID) == HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID {
binary.Read(r, binary.LittleEndian, &channelId)
}
if errorCode > 0 {
return 0, fmt.Errorf("channel response error %d", errorCode)
}
return channelId, nil
}

View File

@@ -0,0 +1,148 @@
package protocol
import (
"bytes"
"encoding/binary"
"errors"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"io"
"log"
"net"
"os"
"syscall"
)
type RedirectFlags struct {
Clipboard bool
Port bool
Drive bool
Printer bool
Pnp bool
DisableAll bool
EnableAll bool
}
type SessionInfo struct {
// The connection-id (RDG-ConnID) as reported by the client
ConnId string
// The underlying incoming transport being either websocket or legacy http
// in case of websocket TransportOut will equal TransportIn
TransportIn transport.Transport
// The underlying outgoing transport being either websocket or legacy http
// in case of websocket TransportOut will equal TransportOut
TransportOut transport.Transport
// The remote desktop server (rdp, vnc etc) the clients intends to connect to
RemoteServer string
// The obtained client ip address
ClientIp string
}
// readMessage parses and defragments a packet from a Transport. It returns
// at most the bytes that have been reported by the packet
func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) {
fragment := false
index := 0
buf := make([]byte, 4096)
for {
size, pkt, err := in.ReadPacket()
if err != nil {
return 0, 0, []byte{0, 0}, err
}
// check for fragments
var pt uint16
var sz uint32
var msg []byte
if !fragment {
pt, sz, msg, err = readHeader(pkt[:size])
if err != nil {
fragment = true
index = copy(buf, pkt[:size])
continue
}
index = 0
} else {
fragment = false
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
// header is corrupted even after defragmenting
if err != nil {
return 0, 0, []byte{0, 0}, err
}
}
if !fragment {
return int(pt), int(sz), msg, nil
}
}
}
// createPacket wraps the data into the protocol packet
func createPacket(pktType uint16, data []byte) (packet []byte) {
size := len(data) + 8
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint16(pktType))
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
binary.Write(buf, binary.LittleEndian, uint32(size))
buf.Write(data)
return buf.Bytes()
}
// readHeader parses a packet and verifies its reported size
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
// header needs to be 8 min
if len(data) < 8 {
return 0, 0, nil, errors.New("header too short, fragment likely")
}
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &packetType)
r.Seek(4, io.SeekStart)
binary.Read(r, binary.LittleEndian, &size)
if len(data) < int(size) {
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
}
return packetType, size, data[8:size], nil
}
// forwards data from a Connection to Transport and wraps it in the rdpgw protocol
func forward(in net.Conn, out transport.Transport) {
defer in.Close()
b1 := new(bytes.Buffer)
buf := make([]byte, 4086)
for {
n, err := in.Read(buf)
if err != nil {
log.Printf("Error reading from local conn %s", err)
break
}
binary.Write(b1, binary.LittleEndian, uint16(n))
b1.Write(buf[:n])
out.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset()
}
}
// receive data received from the gateway client, unwrap and forward the remote desktop server
func receive(data []byte, out net.Conn) {
buf := bytes.NewReader(data)
var cblen uint16
binary.Read(buf, binary.LittleEndian, &cblen)
pkt := make([]byte, cblen)
binary.Read(buf, binary.LittleEndian, &pkt)
out.Write(pkt)
}
// wrapSyscallError takes an error and a syscall name. If the error is
// a syscall.Errno, it wraps it in a os.SyscallError using the syscall name.
func wrapSyscallError(name string, err error) error {
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError(name, err)
}
return err
}

View File

@@ -0,0 +1,213 @@
package protocol
import (
"context"
"errors"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus"
"log"
"net"
"net/http"
"reflect"
"syscall"
"time"
)
const (
rdgConnectionIdKey = "Rdg-Connection-Id"
MethodRDGIN = "RDG_IN_DATA"
MethodRDGOUT = "RDG_OUT_DATA"
)
var (
connectionCache = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "connection_cache",
Help: "The amount of connections in the cache",
})
websocketConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "websocket_connections",
Help: "The count of websocket connections",
})
legacyConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "rdpgw",
Name: "legacy_connections",
Help: "The count of legacy https connections",
})
)
type Gateway struct {
ServerConf *ServerConf
}
var upgrader = websocket.Upgrader{}
var c = cache.New(5*time.Minute, 10*time.Minute)
func init() {
prometheus.MustRegister(connectionCache)
prometheus.MustRegister(legacyConnections)
prometheus.MustRegister(websocketConnections)
}
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
connectionCache.Set(float64(c.ItemCount()))
var s *SessionInfo
connId := r.Header.Get(rdgConnectionIdKey)
x, found := c.Get(connId)
if !found {
s = &SessionInfo{ConnId: connId}
} else {
s = x.(*SessionInfo)
}
ctx := context.WithValue(r.Context(), "SessionInfo", s)
if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
return
}
r.Method = "GET" // force
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("Cannot upgrade falling back to old protocol: %s", err)
return
}
defer conn.Close()
err = g.setSendReceiveBuffers(conn.UnderlyingConn())
if err != nil {
log.Printf("Cannot set send/receive buffers: %s", err)
}
g.handleWebsocketProtocol(ctx, conn, s)
} else if r.Method == MethodRDGIN {
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
}
}
func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
if g.ServerConf.SendBuf < 1 && g.ServerConf.ReceiveBuf < 1 {
return nil
}
// conn == tls.Conn
ptr := reflect.ValueOf(conn)
val := reflect.Indirect(ptr)
if val.Kind() != reflect.Struct {
return errors.New("didn't get a struct from conn")
}
// this gets net.Conn -> *net.TCPConn -> net.TCPConn
ptrConn := val.FieldByName("conn")
valConn := reflect.Indirect(ptrConn)
if !valConn.IsValid() {
return errors.New("cannot find conn field")
}
valConn = valConn.Elem().Elem()
// net.FD
ptrNetFd := valConn.FieldByName("fd")
valNetFd := reflect.Indirect(ptrNetFd)
if !valNetFd.IsValid() {
return errors.New("cannot find fd field")
}
// pfd member
ptrPfd := valNetFd.FieldByName("pfd")
valPfd := reflect.Indirect(ptrPfd)
if !valPfd.IsValid() {
return errors.New("cannot find pfd field")
}
// finally the exported Sysfd
ptrSysFd := valPfd.FieldByName("Sysfd")
if !ptrSysFd.IsValid() {
return errors.New("cannot find Sysfd field")
}
fd := int(ptrSysFd.Int())
if g.ServerConf.ReceiveBuf > 0 {
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ServerConf.ReceiveBuf)
if err != nil {
return wrapSyscallError("setsockopt", err)
}
}
if g.ServerConf.SendBuf > 0 {
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.ServerConf.SendBuf)
if err != nil {
return wrapSyscallError("setsockopt", err)
}
}
return nil
}
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
websocketConnections.Inc()
defer websocketConnections.Dec()
inout, _ := transport.NewWS(c)
s.TransportOut = inout
s.TransportIn = inout
handler := NewServer(s, g.ServerConf)
handler.Process(ctx)
}
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
// and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different
// to ensure the connections do not get cached or terminated by a proxy prematurely.
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s *SessionInfo) {
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
if r.Method == MethodRDGOUT {
out, err := transport.NewLegacy(w)
if err != nil {
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return
}
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
s.TransportOut = out
out.SendAccept(true)
c.Set(s.ConnId, s, cache.DefaultExpiration)
} else if r.Method == MethodRDGIN {
legacyConnections.Inc()
defer legacyConnections.Dec()
in, err := transport.NewLegacy(w)
if err != nil {
log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
return
}
defer in.Close()
if s.TransportIn == nil {
s.TransportIn = in
c.Set(s.ConnId, s, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
in.SendAccept(false)
// read some initial data
in.Drain()
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
handler := NewServer(s, g.ServerConf)
handler.Process(r.Context())
}
}
}

View File

@@ -0,0 +1,211 @@
package protocol
import (
"fmt"
"log"
"testing"
)
const (
HeaderLen = 8
HandshakeRequestLen = HeaderLen + 6
HandshakeResponseLen = HeaderLen + 10
TunnelCreateRequestLen = HeaderLen + 8 // + dynamic
TunnelCreateResponseLen = HeaderLen + 18
TunnelAuthLen = HeaderLen + 2 // + dynamic
TunnelAuthResponseLen = HeaderLen + 16
ChannelCreateLen = HeaderLen + 8 // + dynamic
ChannelResponseLen = HeaderLen + 12
)
func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) {
pt, size, pkt, err := readHeader(data)
if pt != expPt {
return 0, 0, []byte{}, fmt.Errorf("readHeader failed, expected packet type %d got %d", expPt, pt)
}
if size != expSize {
return 0, 0, []byte{}, fmt.Errorf("readHeader failed, expected size %d, got %d", expSize, size)
}
if err != nil {
return 0, 0, []byte{}, err
}
return pt, size, pkt, nil
}
func TestHandshake(t *testing.T) {
client := ClientConfig{
PAAToken: "abab",
}
s := &SessionInfo{}
hc := &ServerConf{
TokenAuth: true,
}
h := NewServer(s, hc)
data := client.handshakeRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_REQUEST, HandshakeRequestLen)
if err != nil {
t.Fatalf("verifyHeader failed: %s", err)
}
log.Printf("pkt: %x", pkt)
major, minor, version, extAuth := h.handshakeRequest(pkt)
if major != MajorVersion || minor != MinorVersion || version != Version {
t.Fatalf("handshakeRequest failed got version %d.%d protocol %d, expected %d.%d protocol %d",
major, minor, version, MajorVersion, MinorVersion, Version)
}
if !((extAuth & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
t.Fatalf("handshakeRequest failed got ext auth %d, expected %d", extAuth, extAuth|HTTP_EXTENDED_AUTH_PAA)
}
data = h.handshakeResponse(0x0, 0x0)
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_RESPONSE, HandshakeResponseLen)
if err != nil {
t.Fatalf("verifyHeader failed: %s", err)
}
log.Printf("pkt: %x", pkt)
caps, err := client.handshakeResponse(pkt)
if !((caps & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
t.Fatalf("handshakeResponse failed got caps %d, expected %d", caps, caps|HTTP_EXTENDED_AUTH_PAA)
}
}
func TestTunnelCreation(t *testing.T) {
client := ClientConfig{
PAAToken: "abab",
}
s := &SessionInfo{}
hc := &ServerConf{
TokenAuth: true,
}
h := NewServer(s, hc)
data := client.tunnelRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
uint32(TunnelCreateRequestLen+2+len(client.PAAToken)*2))
if err != nil {
t.Fatalf("verifyHeader failed: %s", err)
}
caps, token := h.tunnelRequest(pkt)
if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
t.Fatalf("tunnelRequest failed got caps %d, expected %d", caps, caps|HTTP_CAPABILITY_IDLE_TIMEOUT)
}
if token != client.PAAToken {
t.Fatalf("tunnelRequest failed got token %s, expected %s", token, client.PAAToken)
}
data = h.tunnelResponse()
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_RESPONSE, TunnelCreateResponseLen)
if err != nil {
t.Fatalf("verifyHeader failed: %s", err)
}
tid, caps, err := client.tunnelResponse(pkt)
if err != nil {
t.Fatalf("Error %s", err)
}
if tid != tunnelId {
t.Fatalf("tunnelResponse failed tunnel id %d, expected %d", tid, tunnelId)
}
if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
t.Fatalf("tunnelResponse failed got caps %d, expected %d", caps, caps|HTTP_CAPABILITY_IDLE_TIMEOUT)
}
}
func TestTunnelAuth(t *testing.T) {
name := "test_name"
client := ClientConfig{
Name: name,
}
s := &SessionInfo{}
hc := &ServerConf{
TokenAuth: true,
IdleTimeout: 10,
RedirectFlags: RedirectFlags{
Clipboard: true,
},
}
h := NewServer(s, hc)
data := client.tunnelAuthRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2))
if err != nil {
t.Fatalf("verifyHeader failed: %s", err)
}
n := h.tunnelAuthRequest(pkt)
if n != name {
t.Fatalf("tunnelAuthRequest failed got name %s, expected %s", n, name)
}
data = h.tunnelAuthResponse()
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH_RESPONSE, TunnelAuthResponseLen)
if err != nil {
t.Fatalf("verifyHeader failed: %s", err)
}
flags, timeout, err := client.tunnelAuthResponse(pkt)
if err != nil {
t.Fatalf("tunnel auth error %s", err)
}
if (flags & HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD) == HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD {
t.Fatalf("tunnelAuthResponse failed got flags %d, expected %d",
flags, flags|HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD)
}
if int(timeout) != hc.IdleTimeout {
t.Fatalf("tunnelAuthResponse failed got timeout %d, expected %d",
timeout, hc.IdleTimeout)
}
}
func TestChannelCreation(t *testing.T) {
server := "test_server"
client := ClientConfig{
Server: server,
Port: 3389,
}
s := &SessionInfo{}
hc := &ServerConf{
TokenAuth: true,
IdleTimeout: 10,
RedirectFlags: RedirectFlags{
Clipboard: true,
},
}
h := NewServer(s, hc)
data := client.channelRequest()
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))
if err != nil {
t.Fatalf("verifyHeader failed: %s", err)
}
hServer, hPort := h.channelRequest(pkt)
if hServer != server {
t.Fatalf("channelRequest failed got server %s, expected %s", hServer, server)
}
if int(hPort) != client.Port {
t.Fatalf("channelRequest failed got port %d, expected %d", hPort, client.Port)
}
data = h.channelResponse()
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_CHANNEL_RESPONSE, uint32(ChannelResponseLen))
if err != nil {
t.Fatalf("verifyHeader failed: %s", err)
}
channelId, err := client.channelResponse(pkt)
if err != nil {
t.Fatalf("channelResponse failed: %s", err)
}
if channelId < 1 {
t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId)
}
}

View File

@@ -0,0 +1,340 @@
package protocol
import (
"bytes"
"context"
"encoding/binary"
"errors"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"io"
"log"
"net"
"strconv"
"time"
)
type VerifyTunnelCreate func(context.Context, string) (bool, error)
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
type VerifyServerFunc func(context.Context, string) (bool, error)
type Server struct {
Session *SessionInfo
VerifyTunnelCreate VerifyTunnelCreate
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
VerifyServerFunc VerifyServerFunc
RedirectFlags int
IdleTimeout int
SmartCardAuth bool
TokenAuth bool
ClientName string
Remote net.Conn
State int
}
type ServerConf struct {
VerifyTunnelCreate VerifyTunnelCreate
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
VerifyServerFunc VerifyServerFunc
RedirectFlags RedirectFlags
IdleTimeout int
SmartCardAuth bool
TokenAuth bool
ReceiveBuf int
SendBuf int
}
func NewServer(s *SessionInfo, conf *ServerConf) *Server {
h := &Server{
State: SERVER_STATE_INITIAL,
Session: s,
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
IdleTimeout: conf.IdleTimeout,
SmartCardAuth: conf.SmartCardAuth,
TokenAuth: conf.TokenAuth,
VerifyTunnelCreate: conf.VerifyTunnelCreate,
VerifyServerFunc: conf.VerifyServerFunc,
VerifyTunnelAuthFunc: conf.VerifyTunnelAuthFunc,
}
return h
}
const tunnelId = 10
func (s *Server) Process(ctx context.Context) error {
for {
pt, sz, pkt, err := readMessage(s.Session.TransportIn)
if err != nil {
log.Printf("Cannot read message from stream %s", err)
return err
}
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
if s.State != SERVER_STATE_INITIAL {
log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIAL)
return errors.New("wrong state")
}
major, minor, _, _ := s.handshakeRequest(pkt) // todo check if auth matches what the handler can do
msg := s.handshakeResponse(major, minor)
s.Session.TransportOut.WritePacket(msg)
s.State = SERVER_STATE_HANDSHAKE
case PKT_TYPE_TUNNEL_CREATE:
log.Printf("Tunnel create")
if s.State != SERVER_STATE_HANDSHAKE {
log.Printf("Tunnel create attempted while in wrong state %d != %d",
s.State, SERVER_STATE_HANDSHAKE)
return errors.New("wrong state")
}
_, cookie := s.tunnelRequest(pkt)
if s.VerifyTunnelCreate != nil {
if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
return errors.New("invalid PAA cookie")
}
}
msg := s.tunnelResponse()
s.Session.TransportOut.WritePacket(msg)
s.State = SERVER_STATE_TUNNEL_CREATE
case PKT_TYPE_TUNNEL_AUTH:
log.Printf("Tunnel auth")
if s.State != SERVER_STATE_TUNNEL_CREATE {
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
s.State, SERVER_STATE_TUNNEL_CREATE)
return errors.New("wrong state")
}
client := s.tunnelAuthRequest(pkt)
if s.VerifyTunnelAuthFunc != nil {
if ok, _ := s.VerifyTunnelAuthFunc(ctx, client); !ok {
log.Printf("Invalid client name: %s", client)
return errors.New("invalid client name")
}
}
msg := s.tunnelAuthResponse()
s.Session.TransportOut.WritePacket(msg)
s.State = SERVER_STATE_TUNNEL_AUTHORIZE
case PKT_TYPE_CHANNEL_CREATE:
log.Printf("Channel create")
if s.State != SERVER_STATE_TUNNEL_AUTHORIZE {
log.Printf("Channel create attempted while in wrong state %d != %d",
s.State, SERVER_STATE_TUNNEL_AUTHORIZE)
return errors.New("wrong state")
}
server, port := s.channelRequest(pkt)
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
if s.VerifyServerFunc != nil {
if ok, _ := s.VerifyServerFunc(ctx, host); !ok {
log.Printf("Not allowed to connect to %s by policy handler", host)
return errors.New("denied by security policy")
}
}
log.Printf("Establishing connection to RDP server: %s", host)
s.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
if err != nil {
log.Printf("Error connecting to %s, %s", host, err)
return err
}
log.Printf("Connection established")
msg := s.channelResponse()
s.Session.TransportOut.WritePacket(msg)
// Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually
go forward(s.Remote, s.Session.TransportOut)
s.State = SERVER_STATE_CHANNEL_CREATE
case PKT_TYPE_DATA:
if s.State < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Data received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state")
}
s.State = SERVER_STATE_OPENED
receive(pkt, s.Remote)
case PKT_TYPE_KEEPALIVE:
// keepalives can be received while the channel is not open yet
if s.State < SERVER_STATE_CHANNEL_CREATE {
log.Printf("Keepalive received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE)
return errors.New("wrong state")
}
// avoid concurrency issues
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL:
log.Printf("Close channel")
if s.State != SERVER_STATE_OPENED {
log.Printf("Channel closed while in wrong state %d != %d", s.State, SERVER_STATE_OPENED)
return errors.New("wrong state")
}
s.Session.TransportIn.Close()
s.Session.TransportOut.Close()
s.State = SERVER_STATE_CLOSED
default:
log.Printf("Unknown packet (size %d): %x", sz, pkt)
}
}
}
// Creates a packet the is a response to a handshakeRequest request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure
func (s *Server) handshakeResponse(major byte, minor byte) []byte {
var caps uint16
if s.SmartCardAuth {
caps = caps | HTTP_EXTENDED_AUTH_SC
}
if s.TokenAuth {
caps = caps | HTTP_EXTENDED_AUTH_PAA
}
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code
buf.Write([]byte{major, minor})
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
binary.Write(buf, binary.LittleEndian, uint16(caps)) // extended auth
return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
}
func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &major)
binary.Read(r, binary.LittleEndian, &minor)
binary.Read(r, binary.LittleEndian, &version)
binary.Read(r, binary.LittleEndian, &extAuth)
log.Printf("major: %d, minor: %d, version: %d, ext auth: %d", major, minor, version, extAuth)
return
}
func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) {
var fields uint16
r := bytes.NewReader(data)
binary.Read(r, binary.LittleEndian, &caps)
binary.Read(r, binary.LittleEndian, &fields)
r.Seek(2, io.SeekCurrent)
if fields == HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE {
var size uint16
binary.Read(r, binary.LittleEndian, &size)
cookieB := make([]byte, size)
r.Read(cookieB)
cookie, _ = DecodeUTF16(cookieB)
}
return
}
func (s *Server) tunnelResponse() []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID|HTTP_TUNNEL_RESPONSE_FIELD_CAPS)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// tunnel id (when is it used?)
binary.Write(buf, binary.LittleEndian, uint32(tunnelId))
binary.Write(buf, binary.LittleEndian, uint32(HTTP_CAPABILITY_IDLE_TIMEOUT))
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
}
func (s *Server) tunnelAuthRequest(data []byte) string {
buf := bytes.NewReader(data)
var size uint16
binary.Read(buf, binary.LittleEndian, &size)
clData := make([]byte, size)
binary.Read(buf, binary.LittleEndian, &clData)
clientName, _ := DecodeUTF16(clData)
return clientName
}
func (s *Server) tunnelAuthResponse() []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// idle timeout
if s.IdleTimeout < 0 {
s.IdleTimeout = 0
}
binary.Write(buf, binary.LittleEndian, uint32(s.RedirectFlags)) // redir flags
binary.Write(buf, binary.LittleEndian, uint32(s.IdleTimeout)) // timeout in minutes
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
}
func (s *Server) channelRequest(data []byte) (server string, port uint16) {
buf := bytes.NewReader(data)
var resourcesSize byte
var alternative byte
var protocol uint16
var nameSize uint16
binary.Read(buf, binary.LittleEndian, &resourcesSize)
binary.Read(buf, binary.LittleEndian, &alternative)
binary.Read(buf, binary.LittleEndian, &port)
binary.Read(buf, binary.LittleEndian, &protocol)
binary.Read(buf, binary.LittleEndian, &nameSize)
nameData := make([]byte, nameSize)
binary.Read(buf, binary.LittleEndian, &nameData)
server, _ = DecodeUTF16(nameData)
return
}
func (s *Server) channelResponse() []byte {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
binary.Write(buf, binary.LittleEndian, uint16(HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID)) // fields present
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
// channel id is required for Windows clients
binary.Write(buf, binary.LittleEndian, uint32(1)) // channel id
// optional fields
// channel id uint32 (4)
// udp port uint16 (2)
// udp auth cookie 1 byte for side channel
// length uint16
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
}
func makeRedirectFlags(flags RedirectFlags) int {
var redir = 0
if flags.DisableAll {
return HTTP_TUNNEL_REDIR_DISABLE_ALL
}
if flags.EnableAll {
return HTTP_TUNNEL_REDIR_ENABLE_ALL
}
if !flags.Port {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT
}
if !flags.Clipboard {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD
}
if !flags.Drive {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE
}
if !flags.Pnp {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP
}
if !flags.Printer {
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER
}
return redir
}

View File

@@ -0,0 +1,78 @@
package protocol
const (
PKT_TYPE_HANDSHAKE_REQUEST = 0x1
PKT_TYPE_HANDSHAKE_RESPONSE = 0x2
PKT_TYPE_EXTENDED_AUTH_MSG = 0x3
PKT_TYPE_TUNNEL_CREATE = 0x4
PKT_TYPE_TUNNEL_RESPONSE = 0x5
PKT_TYPE_TUNNEL_AUTH = 0x6
PKT_TYPE_TUNNEL_AUTH_RESPONSE = 0x7
PKT_TYPE_CHANNEL_CREATE = 0x8
PKT_TYPE_CHANNEL_RESPONSE = 0x9
PKT_TYPE_DATA = 0xA
PKT_TYPE_SERVICE_MESSAGE = 0xB
PKT_TYPE_REAUTH_MESSAGE = 0xC
PKT_TYPE_KEEPALIVE = 0xD
PKT_TYPE_CLOSE_CHANNEL = 0x10
PKT_TYPE_CLOSE_CHANNEL_RESPONSE = 0x11
)
const (
HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID = 0x01
HTTP_TUNNEL_RESPONSE_FIELD_CAPS = 0x02
HTTP_TUNNEL_RESPONSE_FIELD_SOH_REQ = 0x04
HTTP_TUNNEL_RESPONSE_FIELD_CONSENT_MSG = 0x10
)
const (
HTTP_EXTENDED_AUTH_NONE = 0x0
HTTP_EXTENDED_AUTH_SC = 0x1 /* Smart card authentication. */
HTTP_EXTENDED_AUTH_PAA = 0x02 /* Pluggable authentication. */
HTTP_EXTENDED_AUTH_SSPI_NTLM = 0x04 /* NTLM extended authentication. */
)
const (
HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS = 0x01
HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT = 0x02
HTTP_TUNNEL_AUTH_RESPONSE_FIELD_SOH_RESPONSE = 0x04
)
const (
HTTP_TUNNEL_REDIR_ENABLE_ALL = 0x80000000
HTTP_TUNNEL_REDIR_DISABLE_ALL = 0x40000000
HTTP_TUNNEL_REDIR_DISABLE_DRIVE = 0x01
HTTP_TUNNEL_REDIR_DISABLE_PRINTER = 0x02
HTTP_TUNNEL_REDIR_DISABLE_PORT = 0x04
HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD = 0x08
HTTP_TUNNEL_REDIR_DISABLE_PNP = 0x10
)
const (
HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID = 0x01
HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE = 0x02
HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT = 0x04
)
const (
HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
)
const (
SERVER_STATE_INITIAL = 0x0
SERVER_STATE_HANDSHAKE = 0x1
SERVER_STATE_TUNNEL_CREATE = 0x2
SERVER_STATE_TUNNEL_AUTHORIZE = 0x3
SERVER_STATE_CHANNEL_CREATE = 0x4
SERVER_STATE_OPENED = 0x5
SERVER_STATE_CLOSED = 0x6
)
const (
HTTP_CAPABILITY_TYPE_QUAR_SOH = 0x1
HTTP_CAPABILITY_IDLE_TIMEOUT = 0x2
HTTP_CAPABILITY_MESSAGING_CONSENT_SIGN = 0x4
HTTP_CAPABILITY_MESSAGING_SERVICE_MSG = 0x8
HTTP_CAPABILITY_REAUTH = 0x10
HTTP_CAPABILITY_UDP_TRANSPORT = 0x20
)

View File

@@ -0,0 +1,42 @@
package protocol
import (
"bytes"
"encoding/binary"
"fmt"
"unicode/utf16"
"unicode/utf8"
)
func DecodeUTF16(b []byte) (string, error) {
if len(b)%2 != 0 {
return "", fmt.Errorf("must have even length byte slice")
}
u16s := make([]uint16, 1)
ret := &bytes.Buffer{}
b8buf := make([]byte, 4)
lb := len(b)
for i := 0; i < lb; i += 2 {
u16s[0] = uint16(b[i]) + (uint16(b[i+1]) << 8)
r := utf16.Decode(u16s)
n := utf8.EncodeRune(b8buf, r[0])
ret.Write(b8buf[:n])
}
bret := ret.Bytes()
if len(bret) > 0 && bret[len(bret)-1] == '\x00' {
bret = bret[:len(bret)-1]
}
return string(bret), nil
}
func EncodeUTF16(s string) []byte {
ret := new(bytes.Buffer)
enc := utf16.Encode([]rune(s))
for c := range enc {
binary.Write(ret, binary.LittleEndian, enc[c])
}
return ret.Bytes()
}

240
cmd/rdpgw/security/jwt.go Normal file
View File

@@ -0,0 +1,240 @@
package security
import (
"context"
"errors"
"fmt"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/square/go-jose/v3"
"github.com/square/go-jose/v3/jwt"
"golang.org/x/oauth2"
"log"
"time"
)
var (
SigningKey []byte
EncryptionKey []byte
UserSigningKey []byte
UserEncryptionKey []byte
OIDCProvider *oidc.Provider
Oauth2Config oauth2.Config
)
var ExpiryTime time.Duration = 5
var VerifyClientIP bool = true
type customClaims struct {
RemoteServer string `json:"remoteServer"`
ClientIP string `json:"clientIp"`
AccessToken string `json:"accessToken"`
}
func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
token, err := jwt.ParseSigned(tokenString)
// check if the signing algo matches what we expect
for _, header := range token.Headers {
if header.Algorithm != string(jose.HS256) {
return false, fmt.Errorf("unexpected signing method: %v", header.Algorithm)
}
}
standard := jwt.Claims{}
custom := customClaims{}
// Claims automagically checks the signature...
err = token.Claims(SigningKey, &standard, &custom)
if err != nil {
log.Printf("token signature validation failed due to %s", err)
return false, err
}
// ...but doesn't check the expiry claim :/
err = standard.Validate(jwt.Expected{
Issuer: "rdpgw",
Time: time.Now(),
})
if err != nil {
log.Printf("token validation failed due to %s", err)
return false, err
}
// validate the access token
tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken})
_, err = OIDCProvider.UserInfo(ctx, tokenSource)
if err != nil {
log.Printf("Cannot get user info for access token: %s", err)
return false, err
}
s := getSessionInfo(ctx)
s.RemoteServer = custom.RemoteServer
s.ClientIp = custom.ClientIP
return true, nil
}
func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
s := getSessionInfo(ctx)
if s == nil {
return false, errors.New("no valid session info found in context")
}
if s.RemoteServer != host {
log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer)
return false, nil
}
if VerifyClientIP && s.ClientIp != common.GetClientIp(ctx) {
log.Printf("Current client ip address %s does not match token client ip %s",
common.GetClientIp(ctx), s.ClientIp)
return false, nil
}
return true, nil
}
func GeneratePAAToken(ctx context.Context, username string, server string) (string, error) {
if len(SigningKey) < 32 {
return "", errors.New("token signing key not long enough or not specified")
}
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: SigningKey}, nil)
if err != nil {
log.Printf("Cannot obtain signer %s", err)
return "", err
}
standard := jwt.Claims{
Issuer: "rdpgw",
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
Subject: username,
}
private := customClaims{
RemoteServer: server,
ClientIP: common.GetClientIp(ctx),
AccessToken: common.GetAccessToken(ctx),
}
if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {
log.Printf("Cannot sign PAA token %s", err)
return "", err
} else {
return token, nil
}
}
func GenerateUserToken(ctx context.Context, userName string) (string, error) {
if len(UserEncryptionKey) < 32 {
return "", errors.New("user token encryption key not long enough or not specified")
}
claims := jwt.Claims{
Subject: userName,
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
Issuer: "rdpgw",
}
enc, err := jose.NewEncrypter(
jose.A128CBC_HS256,
jose.Recipient{Algorithm: jose.DIRECT, Key: UserEncryptionKey},
(&jose.EncrypterOptions{Compression: jose.DEFLATE}).WithContentType("JWT"),
)
if err != nil {
log.Printf("Cannot encrypt user token due to %s", err)
return "", err
}
// this makes the token bigger and we deal with a limited space of 511 characters
// sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: SigningKey}, nil)
// token, err := jwt.SignedAndEncrypted(sig, enc).Claims(claims).CompactSerialize()
token, err := jwt.Encrypted(enc).Claims(claims).CompactSerialize()
return token, err
}
func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
standard := jwt.Claims{}
if len(UserEncryptionKey) > 0 && len(UserSigningKey) > 0 {
enc, err := jwt.ParseSignedAndEncrypted(token)
if err != nil {
log.Printf("Cannot get token %s", err)
return standard, errors.New("cannot get token")
}
token, err := enc.Decrypt(UserEncryptionKey)
if err != nil {
log.Printf("Cannot decrypt token %s", err)
return standard, errors.New("cannot decrypt token")
}
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
log.Printf("signature validation failure: %s", err)
return standard, errors.New("signature validation failure")
}
if err = token.Claims(UserSigningKey, &standard); err != nil {
log.Printf("cannot verify signature %s", err)
return standard, errors.New("cannot verify signature")
}
} else if len(UserSigningKey) == 0 {
token, err := jwt.ParseEncrypted(token)
if err != nil {
log.Printf("Cannot get token %s", err)
return standard, errors.New("cannot get token")
}
err = token.Claims(UserEncryptionKey, &standard)
if err != nil {
log.Printf("Cannot decrypt token %s", err)
return standard, errors.New("cannot decrypt token")
}
} else {
token, err := jwt.ParseSigned(token)
if err != nil {
log.Printf("Cannot get token %s", err)
return standard, errors.New("cannot get token")
}
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
log.Printf("signature validation failure: %s", err)
return standard, errors.New("signature validation failure")
}
err = token.Claims(UserSigningKey, &standard)
if err = token.Claims(UserSigningKey, &standard); err != nil {
log.Printf("cannot verify signature %s", err)
return standard, errors.New("cannot verify signature")
}
}
// go-jose doesnt verify the expiry
err := standard.Validate(jwt.Expected{
Issuer: "rdpgw",
Time: time.Now(),
})
if err != nil {
log.Printf("token validation failed due to %s", err)
return standard, fmt.Errorf("token validation failed due to %s", err)
}
return standard, nil
}
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
if !ok {
log.Printf("cannot get session info from context")
return nil
}
return s
}
func verifyAlg(headers []jose.Header, alg string) (bool, error) {
for _, header := range headers {
if header.Algorithm != alg {
return false, fmt.Errorf("invalid signing method %s", header.Algorithm)
}
}
return true, nil
}

View File

@@ -0,0 +1,82 @@
package transport
import (
"bufio"
"errors"
"io"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"time"
)
const (
crlf = "\r\n"
HttpOK = "HTTP/1.1 200 OK\r\n"
)
type LegacyPKT struct {
Conn net.Conn
ChunkedReader io.Reader
Writer *bufio.Writer
}
func NewLegacy(w http.ResponseWriter) (*LegacyPKT, error) {
hj, ok := w.(http.Hijacker)
if ok {
conn, rw, err := hj.Hijack()
l := &LegacyPKT{
Conn: conn,
ChunkedReader: httputil.NewChunkedReader(rw.Reader),
Writer: rw.Writer,
}
return l, err
}
return nil, errors.New("cannot hijack connection")
}
func (t *LegacyPKT) ReadPacket() (n int, p []byte, err error){
buf := make([]byte, 4096) // bufio.defaultBufSize
n, err = t.ChunkedReader.Read(buf)
p = make([]byte, n)
copy(p, buf)
return n, p, err
}
func (t *LegacyPKT) WritePacket(b []byte) (n int, err error) {
return t.Conn.Write(b)
}
func (t *LegacyPKT) Close() error {
return t.Conn.Close()
}
// [MS-TSGU]: Terminal Services Gateway Server Protocol version 39.0
// The server sends back the final status code 200 OK, and also a random entity body of limited size (100 bytes).
// This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does
// not specify an entity length in its response. It uses HTTP 1.0 semantics to send the entity body and closes the
// connection after the last byte is sent.
func (t *LegacyPKT) SendAccept(doSeed bool) {
t.Writer.WriteString(HttpOK)
t.Writer.WriteString("Date: " + time.Now().Format(time.RFC1123) + crlf)
if !doSeed {
t.Writer.WriteString("Content-Length: 0" + crlf)
}
t.Writer.WriteString(crlf)
if doSeed {
seed := make([]byte, 10)
rand.Read(seed)
// docs say it's a seed but 2019 responds with ab cd * 5
t.Writer.Write(seed)
}
t.Writer.Flush()
}
func (t *LegacyPKT) Drain() {
p := make([]byte, 32767)
t.Conn.Read(p)
}

View File

@@ -0,0 +1,8 @@
package transport
type Transport interface {
ReadPacket() (n int, p []byte, err error)
WritePacket(b []byte) (n int, err error)
Close() error
}

View File

@@ -0,0 +1,42 @@
package transport
import (
"errors"
"github.com/gorilla/websocket"
)
type WSPKT struct {
Conn *websocket.Conn
}
func NewWS(c *websocket.Conn) (*WSPKT, error) {
w := &WSPKT{Conn: c}
return w, nil
}
func (t *WSPKT) ReadPacket() (n int, b []byte, err error) {
mt, msg, err := t.Conn.ReadMessage()
if err != nil {
return 0, []byte{0, 0}, err
}
if mt == websocket.BinaryMessage {
return len(msg), msg, nil
}
return len(msg), msg, errors.New("not a binary packet")
}
func (t *WSPKT) WritePacket(b []byte) (n int, err error) {
err = t.Conn.WriteMessage(websocket.BinaryMessage, b)
if err != nil {
return 0, err
}
return len(b), nil
}
func (t *WSPKT) Close() error {
return t.Conn.Close()
}