mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-03-27 22:46:37 +00:00
Add Makefile build
This prepares for multiple binaries and distribution builds
This commit is contained in:
40
cmd/rdpgw/api/token.go
Normal file
40
cmd/rdpgw/api/token.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (c *Config) TokenInfo(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Invalid request", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
tokens, ok := r.URL.Query()["access_token"]
|
||||
if !ok || len(tokens[0]) < 1 {
|
||||
log.Printf("Missing access_token in request")
|
||||
http.Error(w, "access_token missing in request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
token := tokens[0]
|
||||
|
||||
info, err := security.UserInfo(context.Background(), token)
|
||||
if err != nil {
|
||||
log.Printf("Token validation failed due to %s", err)
|
||||
http.Error(w, fmt.Sprintf("token validation failed due to %s", err), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
if err = json.NewEncoder(w).Encode(info); err != nil {
|
||||
log.Printf("Cannot encode json due to %s", err)
|
||||
http.Error(w, "cannot encode json", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
220
cmd/rdpgw/api/web.go
Normal file
220
cmd/rdpgw/api/web.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/oauth2"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
RdpGwSession = "RDPGWSESSION"
|
||||
MaxAge = 120
|
||||
)
|
||||
|
||||
type TokenGeneratorFunc func(context.Context, string, string) (string, error)
|
||||
type UserTokenGeneratorFunc func(context.Context, string) (string, error)
|
||||
|
||||
type Config struct {
|
||||
SessionKey []byte
|
||||
SessionEncryptionKey []byte
|
||||
PAATokenGenerator TokenGeneratorFunc
|
||||
UserTokenGenerator UserTokenGeneratorFunc
|
||||
EnableUserToken bool
|
||||
OAuth2Config *oauth2.Config
|
||||
store *sessions.CookieStore
|
||||
OIDCTokenVerifier *oidc.IDTokenVerifier
|
||||
stateStore *cache.Cache
|
||||
Hosts []string
|
||||
GatewayAddress string
|
||||
UsernameTemplate string
|
||||
NetworkAutoDetect int
|
||||
BandwidthAutoDetect int
|
||||
ConnectionType int
|
||||
SplitUserDomain bool
|
||||
DefaultDomain string
|
||||
}
|
||||
|
||||
func (c *Config) NewApi() {
|
||||
if len(c.SessionKey) < 32 {
|
||||
log.Fatal("Session key too small")
|
||||
}
|
||||
if len(c.Hosts) < 1 {
|
||||
log.Fatal("Not enough hosts to connect to specified")
|
||||
}
|
||||
c.store = sessions.NewCookieStore(c.SessionKey, c.SessionEncryptionKey)
|
||||
c.stateStore = cache.New(time.Minute*2, 5*time.Minute)
|
||||
}
|
||||
|
||||
func (c *Config) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
state := r.URL.Query().Get("state")
|
||||
s, found := c.stateStore.Get(state)
|
||||
if !found {
|
||||
http.Error(w, "unknown state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
url := s.(string)
|
||||
|
||||
ctx := context.Background()
|
||||
oauth2Token, err := c.OAuth2Config.Exchange(ctx, r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
idToken, err := c.OIDCTokenVerifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := struct {
|
||||
OAuth2Token *oauth2.Token
|
||||
IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
|
||||
}{oauth2Token, new(json.RawMessage)}
|
||||
|
||||
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := c.store.Get(r, RdpGwSession)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
session.Options.MaxAge = MaxAge
|
||||
session.Values["preferred_username"] = data["preferred_username"]
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["access_token"] = oauth2Token.AccessToken
|
||||
|
||||
if err = session.Save(r, w); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
||||
func (c *Config) Authenticated(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := c.store.Get(r, RdpGwSession)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
found := session.Values["authenticated"]
|
||||
if found == nil || !found.(bool) {
|
||||
seed := make([]byte, 16)
|
||||
rand.Read(seed)
|
||||
state := hex.EncodeToString(seed)
|
||||
c.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration)
|
||||
http.Redirect(w, r, c.OAuth2Config.AuthCodeURL(state), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), "preferred_username", session.Values["preferred_username"])
|
||||
ctx = context.WithValue(ctx, "access_token", session.Values["access_token"])
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userName, ok := ctx.Value("preferred_username").(string)
|
||||
|
||||
if !ok {
|
||||
log.Printf("preferred_username not found in context")
|
||||
http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// do a round robin selection for now
|
||||
rand.Seed(time.Now().Unix())
|
||||
host := c.Hosts[rand.Intn(len(c.Hosts))]
|
||||
host = strings.Replace(host, "{{ preferred_username }}", userName, 1)
|
||||
|
||||
// split the username into user and domain
|
||||
var user = userName
|
||||
var domain = c.DefaultDomain
|
||||
if c.SplitUserDomain {
|
||||
creds := strings.SplitN(userName, "@", 2)
|
||||
user = creds[0]
|
||||
if len(creds) > 1 {
|
||||
domain = creds[1]
|
||||
}
|
||||
}
|
||||
|
||||
render := user
|
||||
if c.UsernameTemplate != "" {
|
||||
render = fmt.Sprintf(c.UsernameTemplate)
|
||||
render = strings.Replace(render, "{{ username }}", user, 1)
|
||||
if c.UsernameTemplate == render {
|
||||
log.Printf("Invalid username template. %s == %s", c.UsernameTemplate, user)
|
||||
http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, err := c.PAATokenGenerator(ctx, user, host)
|
||||
if err != nil {
|
||||
log.Printf("Cannot generate PAA token for user %s due to %s", user, err)
|
||||
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if c.EnableUserToken {
|
||||
userToken, err := c.UserTokenGenerator(ctx, user)
|
||||
if err != nil {
|
||||
log.Printf("Cannot generate token for user %s due to %s", user, err)
|
||||
http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError)
|
||||
}
|
||||
render = strings.Replace(render, "{{ token }}", userToken, 1)
|
||||
}
|
||||
|
||||
// authenticated
|
||||
seed := make([]byte, 16)
|
||||
rand.Read(seed)
|
||||
fn := hex.EncodeToString(seed) + ".rdp"
|
||||
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+fn)
|
||||
w.Header().Set("Content-Type", "application/x-rdp")
|
||||
data := "full address:s:"+host+"\r\n"+
|
||||
"gatewayhostname:s:"+c.GatewayAddress+"\r\n"+
|
||||
"gatewaycredentialssource:i:5\r\n"+
|
||||
"gatewayusagemethod:i:1\r\n"+
|
||||
"gatewayprofileusagemethod:i:1\r\n"+
|
||||
"gatewayaccesstoken:s:"+token+"\r\n"+
|
||||
"networkautodetect:i:"+strconv.Itoa(c.NetworkAutoDetect)+"\r\n"+
|
||||
"bandwidthautodetect:i:"+strconv.Itoa(c.BandwidthAutoDetect)+"\r\n"+
|
||||
"connection type:i:"+strconv.Itoa(c.ConnectionType)+"\r\n"+
|
||||
"username:s:"+render+"\r\n"+
|
||||
"domain:s:"+domain+"\r\n"+
|
||||
"bitmapcachesize:i:32000\r\n"+
|
||||
"smart sizing:i:1\r\n"
|
||||
|
||||
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(data))
|
||||
}
|
||||
60
cmd/rdpgw/common/remote.go
Normal file
60
cmd/rdpgw/common/remote.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
ClientIPCtx = "ClientIP"
|
||||
ProxyAddressesCtx = "ProxyAddresses"
|
||||
RemoteAddressCtx = "RemoteAddress"
|
||||
)
|
||||
|
||||
func EnrichContext(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
h := r.Header.Get("X-Forwarded-For")
|
||||
if h != "" {
|
||||
var proxies []string
|
||||
ips := strings.Split(h, ",")
|
||||
for i := range ips {
|
||||
ips[i] = strings.TrimSpace(ips[i])
|
||||
}
|
||||
clientIp := ips[0]
|
||||
if len(ips) > 1 {
|
||||
proxies = ips[1:]
|
||||
}
|
||||
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
|
||||
ctx = context.WithValue(ctx, ProxyAddressesCtx, proxies)
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, RemoteAddressCtx, r.RemoteAddr)
|
||||
if h == "" {
|
||||
clientIp, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
ctx = context.WithValue(ctx, ClientIPCtx, clientIp)
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func GetClientIp(ctx context.Context) string {
|
||||
s, ok := ctx.Value(ClientIPCtx).(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func GetAccessToken(ctx context.Context) string {
|
||||
token, ok := ctx.Value("access_token").(string)
|
||||
if !ok {
|
||||
log.Printf("cannot get access token from context")
|
||||
return ""
|
||||
}
|
||||
return token
|
||||
}
|
||||
97
cmd/rdpgw/config/configuration.go
Normal file
97
cmd/rdpgw/config/configuration.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/spf13/viper"
|
||||
"log"
|
||||
)
|
||||
|
||||
type Configuration struct {
|
||||
Server ServerConfig
|
||||
OpenId OpenIDConfig
|
||||
Caps RDGCapsConfig
|
||||
Security SecurityConfig
|
||||
Client ClientConfig
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
GatewayAddress string
|
||||
Port int
|
||||
CertFile string
|
||||
KeyFile string
|
||||
Hosts []string
|
||||
RoundRobin bool
|
||||
SessionKey string
|
||||
SessionEncryptionKey string
|
||||
SendBuf int
|
||||
ReceiveBuf int
|
||||
}
|
||||
|
||||
type OpenIDConfig struct {
|
||||
ProviderUrl string
|
||||
ClientId string
|
||||
ClientSecret string
|
||||
}
|
||||
|
||||
type RDGCapsConfig struct {
|
||||
SmartCardAuth bool
|
||||
TokenAuth bool
|
||||
IdleTimeout int
|
||||
RedirectAll bool
|
||||
DisableRedirect bool
|
||||
EnableClipboard bool
|
||||
EnablePrinter bool
|
||||
EnablePort bool
|
||||
EnablePnp bool
|
||||
EnableDrive bool
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
PAATokenEncryptionKey string
|
||||
PAATokenSigningKey string
|
||||
UserTokenEncryptionKey string
|
||||
UserTokenSigningKey string
|
||||
VerifyClientIp bool
|
||||
EnableUserToken bool
|
||||
}
|
||||
|
||||
type ClientConfig struct {
|
||||
NetworkAutoDetect int
|
||||
BandwidthAutoDetect int
|
||||
ConnectionType int
|
||||
UsernameTemplate string
|
||||
SplitUserDomain bool
|
||||
DefaultDomain string
|
||||
}
|
||||
|
||||
func init() {
|
||||
viper.SetDefault("server.certFile", "server.pem")
|
||||
viper.SetDefault("server.keyFile", "key.pem")
|
||||
viper.SetDefault("server.port", 443)
|
||||
viper.SetDefault("client.networkAutoDetect", 1)
|
||||
viper.SetDefault("client.bandwidthAutoDetect", 1)
|
||||
viper.SetDefault("security.verifyClientIp", true)
|
||||
}
|
||||
|
||||
func Load(configFile string) Configuration {
|
||||
var conf Configuration
|
||||
|
||||
viper.SetConfigName("rdpgw")
|
||||
viper.SetConfigFile(configFile)
|
||||
viper.AddConfigPath(".")
|
||||
viper.SetEnvPrefix("RDPGW")
|
||||
viper.AutomaticEnv()
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
log.Fatalf("No config file found (%s)", err)
|
||||
}
|
||||
|
||||
if err := viper.Unmarshal(&conf); err != nil {
|
||||
log.Fatalf("Cannot unmarshal the config file; %s", err)
|
||||
}
|
||||
|
||||
if len(conf.Security.PAATokenSigningKey) < 32 {
|
||||
log.Fatalf("Token signing key not long enough")
|
||||
}
|
||||
|
||||
return conf
|
||||
}
|
||||
148
cmd/rdpgw/main.go
Normal file
148
cmd/rdpgw/main.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/api"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/oauth2"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var cmd = &cobra.Command{
|
||||
Use: "rdpgw",
|
||||
Long: "Remote Desktop Gateway",
|
||||
}
|
||||
|
||||
var (
|
||||
configFile string
|
||||
)
|
||||
|
||||
var conf config.Configuration
|
||||
|
||||
func main() {
|
||||
// get config
|
||||
cmd.PersistentFlags().StringVarP(&configFile, "conf", "c", "rdpgw.yaml", "config file (json, yaml, ini)")
|
||||
conf = config.Load(configFile)
|
||||
|
||||
security.VerifyClientIP = conf.Security.VerifyClientIp
|
||||
|
||||
// set security keys
|
||||
security.SigningKey = []byte(conf.Security.PAATokenSigningKey)
|
||||
security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey)
|
||||
security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey)
|
||||
security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey)
|
||||
|
||||
// set oidc config
|
||||
provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl)
|
||||
if err != nil {
|
||||
log.Fatalf("Cannot get oidc provider: %s", err)
|
||||
}
|
||||
oidcConfig := &oidc.Config{
|
||||
ClientID: conf.OpenId.ClientId,
|
||||
}
|
||||
verifier := provider.Verifier(oidcConfig)
|
||||
|
||||
oauthConfig := oauth2.Config{
|
||||
ClientID: conf.OpenId.ClientId,
|
||||
ClientSecret: conf.OpenId.ClientSecret,
|
||||
RedirectURL: "https://" + conf.Server.GatewayAddress + "/callback",
|
||||
Endpoint: provider.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
security.OIDCProvider = provider
|
||||
security.Oauth2Config = oauthConfig
|
||||
|
||||
api := &api.Config{
|
||||
GatewayAddress: conf.Server.GatewayAddress,
|
||||
OAuth2Config: &oauthConfig,
|
||||
OIDCTokenVerifier: verifier,
|
||||
PAATokenGenerator: security.GeneratePAAToken,
|
||||
UserTokenGenerator: security.GenerateUserToken,
|
||||
EnableUserToken: conf.Security.EnableUserToken,
|
||||
SessionKey: []byte(conf.Server.SessionKey),
|
||||
SessionEncryptionKey: []byte(conf.Server.SessionEncryptionKey),
|
||||
Hosts: conf.Server.Hosts,
|
||||
NetworkAutoDetect: conf.Client.NetworkAutoDetect,
|
||||
UsernameTemplate: conf.Client.UsernameTemplate,
|
||||
BandwidthAutoDetect: conf.Client.BandwidthAutoDetect,
|
||||
ConnectionType: conf.Client.ConnectionType,
|
||||
SplitUserDomain: conf.Client.SplitUserDomain,
|
||||
DefaultDomain: conf.Client.DefaultDomain,
|
||||
}
|
||||
api.NewApi()
|
||||
|
||||
if conf.Server.CertFile == "" || conf.Server.KeyFile == "" {
|
||||
log.Fatal("Both certfile and keyfile need to be specified")
|
||||
}
|
||||
|
||||
//mux := http.NewServeMux()
|
||||
//mux.HandleFunc("*", HelloServer)
|
||||
|
||||
log.Printf("Starting remote desktop gateway server")
|
||||
|
||||
cfg := &tls.Config{}
|
||||
tlsDebug := os.Getenv("SSLKEYLOGFILE")
|
||||
if tlsDebug != "" {
|
||||
w, err := os.OpenFile(tlsDebug, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
log.Fatalf("Cannot open key log file %s for writing %s", tlsDebug, err)
|
||||
}
|
||||
log.Printf("Key log file set to: %s", tlsDebug)
|
||||
cfg.KeyLogWriter = w
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(conf.Server.CertFile, conf.Server.KeyFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
cfg.Certificates = append(cfg.Certificates, cert)
|
||||
server := http.Server{
|
||||
Addr: ":" + strconv.Itoa(conf.Server.Port),
|
||||
TLSConfig: cfg,
|
||||
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
|
||||
}
|
||||
|
||||
// create the gateway
|
||||
handlerConfig := protocol.ServerConf{
|
||||
IdleTimeout: conf.Caps.IdleTimeout,
|
||||
TokenAuth: conf.Caps.TokenAuth,
|
||||
SmartCardAuth: conf.Caps.SmartCardAuth,
|
||||
RedirectFlags: protocol.RedirectFlags{
|
||||
Clipboard: conf.Caps.EnableClipboard,
|
||||
Drive: conf.Caps.EnableDrive,
|
||||
Printer: conf.Caps.EnablePrinter,
|
||||
Port: conf.Caps.EnablePort,
|
||||
Pnp: conf.Caps.EnablePnp,
|
||||
DisableAll: conf.Caps.DisableRedirect,
|
||||
EnableAll: conf.Caps.RedirectAll,
|
||||
},
|
||||
VerifyTunnelCreate: security.VerifyPAAToken,
|
||||
VerifyServerFunc: security.VerifyServerFunc,
|
||||
SendBuf: conf.Server.SendBuf,
|
||||
ReceiveBuf: conf.Server.ReceiveBuf,
|
||||
}
|
||||
gw := protocol.Gateway{
|
||||
ServerConf: &handlerConfig,
|
||||
}
|
||||
|
||||
http.Handle("/remoteDesktopGateway/", common.EnrichContext(http.HandlerFunc(gw.HandleGatewayProtocol)))
|
||||
http.Handle("/connect", common.EnrichContext(api.Authenticated(http.HandlerFunc(api.HandleDownload))))
|
||||
http.Handle("/metrics", promhttp.Handler())
|
||||
http.HandleFunc("/tokeninfo", api.TokenInfo)
|
||||
http.HandleFunc("/callback", api.HandleCallback)
|
||||
|
||||
err = server.ListenAndServeTLS("", "")
|
||||
if err != nil {
|
||||
log.Fatal("ListenAndServe: ", err)
|
||||
}
|
||||
}
|
||||
246
cmd/rdpgw/protocol/client.go
Normal file
246
cmd/rdpgw/protocol/client.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
MajorVersion = 0x0
|
||||
MinorVersion = 0x0
|
||||
Version = 0x00
|
||||
)
|
||||
|
||||
type ClientConfig struct {
|
||||
SmartCardAuth bool
|
||||
PAAToken string
|
||||
NTLMAuth bool
|
||||
Session *SessionInfo
|
||||
LocalConn net.Conn
|
||||
Server string
|
||||
Port int
|
||||
Name string
|
||||
}
|
||||
|
||||
func (c *ClientConfig) ConnectAndForward() error {
|
||||
c.Session.TransportOut.WritePacket(c.handshakeRequest())
|
||||
|
||||
for {
|
||||
pt, sz, pkt, err := readMessage(c.Session.TransportIn)
|
||||
if err != nil {
|
||||
log.Printf("Cannot read message from stream %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
switch pt {
|
||||
case PKT_TYPE_HANDSHAKE_RESPONSE:
|
||||
caps, err := c.handshakeResponse(pkt)
|
||||
if err != nil {
|
||||
log.Printf("Cannot connect to %s due to %s", c.Server, err)
|
||||
return err
|
||||
}
|
||||
log.Printf("Handshake response received. Caps: %d", caps)
|
||||
c.Session.TransportOut.WritePacket(c.tunnelRequest())
|
||||
case PKT_TYPE_TUNNEL_RESPONSE:
|
||||
tid, caps, err := c.tunnelResponse(pkt)
|
||||
if err != nil {
|
||||
log.Printf("Cannot setup tunnel due to %s", err)
|
||||
return err
|
||||
}
|
||||
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
|
||||
c.Session.TransportOut.WritePacket(c.tunnelAuthRequest())
|
||||
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
|
||||
flags, timeout, err := c.tunnelAuthResponse(pkt)
|
||||
if err != nil {
|
||||
log.Printf("Cannot do tunnel auth due to %s", err)
|
||||
return err
|
||||
}
|
||||
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
|
||||
c.Session.TransportOut.WritePacket(c.channelRequest())
|
||||
case PKT_TYPE_CHANNEL_RESPONSE:
|
||||
cid, err := c.channelResponse(pkt)
|
||||
if err != nil {
|
||||
log.Printf("Cannot do tunnel auth due to %s", err)
|
||||
return err
|
||||
}
|
||||
if cid < 1 {
|
||||
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
|
||||
}
|
||||
log.Printf("Channel creation succesful. Channel id: %d", cid)
|
||||
go forward(c.LocalConn, c.Session.TransportOut)
|
||||
case PKT_TYPE_DATA:
|
||||
receive(pkt, c.LocalConn)
|
||||
default:
|
||||
log.Printf("Unknown packet type received: %d size %d", pt, sz)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConfig) handshakeRequest() []byte {
|
||||
var caps uint16
|
||||
|
||||
if c.SmartCardAuth {
|
||||
caps = caps | HTTP_EXTENDED_AUTH_SC
|
||||
}
|
||||
|
||||
if len(c.PAAToken) > 0 {
|
||||
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
||||
}
|
||||
|
||||
if c.NTLMAuth {
|
||||
caps = caps | HTTP_EXTENDED_AUTH_SSPI_NTLM
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, byte(MajorVersion))
|
||||
binary.Write(buf, binary.LittleEndian, byte(MinorVersion))
|
||||
binary.Write(buf, binary.LittleEndian, uint16(Version))
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint16(caps))
|
||||
|
||||
return createPacket(PKT_TYPE_HANDSHAKE_REQUEST, buf.Bytes())
|
||||
}
|
||||
|
||||
func (c *ClientConfig) handshakeResponse(data []byte) (caps uint16, err error) {
|
||||
var errorCode int32
|
||||
var major byte
|
||||
var minor byte
|
||||
var version uint16
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
binary.Read(r, binary.LittleEndian, &errorCode)
|
||||
binary.Read(r, binary.LittleEndian, &major)
|
||||
binary.Read(r, binary.LittleEndian, &minor)
|
||||
binary.Read(r, binary.LittleEndian, &version)
|
||||
binary.Read(r, binary.LittleEndian, &caps)
|
||||
|
||||
if errorCode > 0 {
|
||||
return 0, fmt.Errorf("error code: %d", errorCode)
|
||||
}
|
||||
|
||||
return caps, nil
|
||||
}
|
||||
|
||||
func (c *ClientConfig) tunnelRequest() []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
var caps uint32
|
||||
var size uint16
|
||||
var fields uint16
|
||||
|
||||
if len(c.PAAToken) > 0 {
|
||||
fields = fields | HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE
|
||||
}
|
||||
|
||||
caps = caps | HTTP_CAPABILITY_IDLE_TIMEOUT
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, caps)
|
||||
binary.Write(buf, binary.LittleEndian, fields)
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||
|
||||
if len(c.PAAToken) > 0 {
|
||||
utf16Token := EncodeUTF16(c.PAAToken)
|
||||
size = uint16(len(utf16Token))
|
||||
binary.Write(buf, binary.LittleEndian, size)
|
||||
buf.Write(utf16Token)
|
||||
}
|
||||
|
||||
return createPacket(PKT_TYPE_TUNNEL_CREATE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (c *ClientConfig) tunnelResponse(data []byte) (tunnelId uint32, caps uint32, err error) {
|
||||
var version uint16
|
||||
var errorCode uint32
|
||||
var fields uint16
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
binary.Read(r, binary.LittleEndian, &version)
|
||||
binary.Read(r, binary.LittleEndian, &errorCode)
|
||||
binary.Read(r, binary.LittleEndian, &fields)
|
||||
r.Seek(2, io.SeekCurrent)
|
||||
if (fields & HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID) == HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID {
|
||||
binary.Read(r, binary.LittleEndian, &tunnelId)
|
||||
}
|
||||
if (fields & HTTP_TUNNEL_RESPONSE_FIELD_CAPS) == HTTP_TUNNEL_RESPONSE_FIELD_CAPS {
|
||||
binary.Read(r, binary.LittleEndian, &caps)
|
||||
}
|
||||
|
||||
if errorCode != 0 {
|
||||
err = fmt.Errorf("tunnel error %d", errorCode)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientConfig) tunnelAuthRequest() []byte {
|
||||
utf16name := EncodeUTF16(c.Name)
|
||||
size := uint16(len(utf16name))
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
binary.Write(buf, binary.LittleEndian, size)
|
||||
buf.Write(utf16name)
|
||||
|
||||
return createPacket(PKT_TYPE_TUNNEL_AUTH, buf.Bytes())
|
||||
}
|
||||
|
||||
func (c *ClientConfig) tunnelAuthResponse(data []byte) (flags uint32, timeout uint32, err error) {
|
||||
var errorCode uint32
|
||||
var fields uint16
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
binary.Read(r, binary.LittleEndian, &errorCode)
|
||||
binary.Read(r, binary.LittleEndian, &fields)
|
||||
r.Seek(2, io.SeekCurrent)
|
||||
|
||||
if (fields & HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS) == HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS {
|
||||
binary.Read(r, binary.LittleEndian, &flags)
|
||||
}
|
||||
if (fields & HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT) == HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT {
|
||||
binary.Read(r, binary.LittleEndian, &timeout)
|
||||
}
|
||||
|
||||
if errorCode > 0 {
|
||||
return 0, 0, fmt.Errorf("tunnel auth error %d", errorCode)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientConfig) channelRequest() []byte {
|
||||
utf16server := EncodeUTF16(c.Server)
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names
|
||||
binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3)
|
||||
binary.Write(buf, binary.LittleEndian, uint16(c.Port))
|
||||
binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint16(len(utf16server)))
|
||||
buf.Write(utf16server)
|
||||
|
||||
return createPacket(PKT_TYPE_CHANNEL_CREATE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (c *ClientConfig) channelResponse(data []byte) (channelId uint32, err error) {
|
||||
var errorCode uint32
|
||||
var fields uint16
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
binary.Read(r, binary.LittleEndian, &errorCode)
|
||||
binary.Read(r, binary.LittleEndian, &fields)
|
||||
r.Seek(2, io.SeekCurrent)
|
||||
|
||||
if (fields & HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID) == HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID {
|
||||
binary.Read(r, binary.LittleEndian, &channelId)
|
||||
}
|
||||
|
||||
if errorCode > 0 {
|
||||
return 0, fmt.Errorf("channel response error %d", errorCode)
|
||||
}
|
||||
|
||||
return channelId, nil
|
||||
}
|
||||
148
cmd/rdpgw/protocol/common.go
Normal file
148
cmd/rdpgw/protocol/common.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type RedirectFlags struct {
|
||||
Clipboard bool
|
||||
Port bool
|
||||
Drive bool
|
||||
Printer bool
|
||||
Pnp bool
|
||||
DisableAll bool
|
||||
EnableAll bool
|
||||
}
|
||||
|
||||
type SessionInfo struct {
|
||||
// The connection-id (RDG-ConnID) as reported by the client
|
||||
ConnId string
|
||||
// The underlying incoming transport being either websocket or legacy http
|
||||
// in case of websocket TransportOut will equal TransportIn
|
||||
TransportIn transport.Transport
|
||||
// The underlying outgoing transport being either websocket or legacy http
|
||||
// in case of websocket TransportOut will equal TransportOut
|
||||
TransportOut transport.Transport
|
||||
// The remote desktop server (rdp, vnc etc) the clients intends to connect to
|
||||
RemoteServer string
|
||||
// The obtained client ip address
|
||||
ClientIp string
|
||||
}
|
||||
|
||||
// readMessage parses and defragments a packet from a Transport. It returns
|
||||
// at most the bytes that have been reported by the packet
|
||||
func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) {
|
||||
fragment := false
|
||||
index := 0
|
||||
buf := make([]byte, 4096)
|
||||
|
||||
for {
|
||||
size, pkt, err := in.ReadPacket()
|
||||
if err != nil {
|
||||
return 0, 0, []byte{0, 0}, err
|
||||
}
|
||||
|
||||
// check for fragments
|
||||
var pt uint16
|
||||
var sz uint32
|
||||
var msg []byte
|
||||
|
||||
if !fragment {
|
||||
pt, sz, msg, err = readHeader(pkt[:size])
|
||||
if err != nil {
|
||||
fragment = true
|
||||
index = copy(buf, pkt[:size])
|
||||
continue
|
||||
}
|
||||
index = 0
|
||||
} else {
|
||||
fragment = false
|
||||
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
|
||||
// header is corrupted even after defragmenting
|
||||
if err != nil {
|
||||
return 0, 0, []byte{0, 0}, err
|
||||
}
|
||||
}
|
||||
if !fragment {
|
||||
return int(pt), int(sz), msg, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createPacket wraps the data into the protocol packet
|
||||
func createPacket(pktType uint16, data []byte) (packet []byte) {
|
||||
size := len(data) + 8
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint16(pktType))
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||
binary.Write(buf, binary.LittleEndian, uint32(size))
|
||||
buf.Write(data)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// readHeader parses a packet and verifies its reported size
|
||||
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
|
||||
// header needs to be 8 min
|
||||
if len(data) < 8 {
|
||||
return 0, 0, nil, errors.New("header too short, fragment likely")
|
||||
}
|
||||
r := bytes.NewReader(data)
|
||||
binary.Read(r, binary.LittleEndian, &packetType)
|
||||
r.Seek(4, io.SeekStart)
|
||||
binary.Read(r, binary.LittleEndian, &size)
|
||||
if len(data) < int(size) {
|
||||
return packetType, size, data[8:], errors.New("data incomplete, fragment received")
|
||||
}
|
||||
return packetType, size, data[8:size], nil
|
||||
}
|
||||
|
||||
// forwards data from a Connection to Transport and wraps it in the rdpgw protocol
|
||||
func forward(in net.Conn, out transport.Transport) {
|
||||
defer in.Close()
|
||||
|
||||
b1 := new(bytes.Buffer)
|
||||
buf := make([]byte, 4086)
|
||||
|
||||
for {
|
||||
n, err := in.Read(buf)
|
||||
if err != nil {
|
||||
log.Printf("Error reading from local conn %s", err)
|
||||
break
|
||||
}
|
||||
binary.Write(b1, binary.LittleEndian, uint16(n))
|
||||
b1.Write(buf[:n])
|
||||
out.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||
b1.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
// receive data received from the gateway client, unwrap and forward the remote desktop server
|
||||
func receive(data []byte, out net.Conn) {
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
var cblen uint16
|
||||
binary.Read(buf, binary.LittleEndian, &cblen)
|
||||
pkt := make([]byte, cblen)
|
||||
binary.Read(buf, binary.LittleEndian, &pkt)
|
||||
|
||||
out.Write(pkt)
|
||||
}
|
||||
|
||||
// wrapSyscallError takes an error and a syscall name. If the error is
|
||||
// a syscall.Errno, it wraps it in a os.SyscallError using the syscall name.
|
||||
func wrapSyscallError(name string, err error) error {
|
||||
if _, ok := err.(syscall.Errno); ok {
|
||||
err = os.NewSyscallError(name, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
213
cmd/rdpgw/protocol/gateway.go
Normal file
213
cmd/rdpgw/protocol/gateway.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
rdgConnectionIdKey = "Rdg-Connection-Id"
|
||||
MethodRDGIN = "RDG_IN_DATA"
|
||||
MethodRDGOUT = "RDG_OUT_DATA"
|
||||
)
|
||||
|
||||
var (
|
||||
connectionCache = prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: "rdpgw",
|
||||
Name: "connection_cache",
|
||||
Help: "The amount of connections in the cache",
|
||||
})
|
||||
|
||||
websocketConnections = prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: "rdpgw",
|
||||
Name: "websocket_connections",
|
||||
Help: "The count of websocket connections",
|
||||
})
|
||||
|
||||
legacyConnections = prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: "rdpgw",
|
||||
Name: "legacy_connections",
|
||||
Help: "The count of legacy https connections",
|
||||
})
|
||||
)
|
||||
|
||||
type Gateway struct {
|
||||
ServerConf *ServerConf
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
var c = cache.New(5*time.Minute, 10*time.Minute)
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(connectionCache)
|
||||
prometheus.MustRegister(legacyConnections)
|
||||
prometheus.MustRegister(websocketConnections)
|
||||
}
|
||||
|
||||
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
connectionCache.Set(float64(c.ItemCount()))
|
||||
|
||||
var s *SessionInfo
|
||||
|
||||
connId := r.Header.Get(rdgConnectionIdKey)
|
||||
x, found := c.Get(connId)
|
||||
if !found {
|
||||
s = &SessionInfo{ConnId: connId}
|
||||
} else {
|
||||
s = x.(*SessionInfo)
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), "SessionInfo", s)
|
||||
|
||||
if r.Method == MethodRDGOUT {
|
||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
|
||||
return
|
||||
}
|
||||
r.Method = "GET" // force
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Printf("Cannot upgrade falling back to old protocol: %s", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = g.setSendReceiveBuffers(conn.UnderlyingConn())
|
||||
if err != nil {
|
||||
log.Printf("Cannot set send/receive buffers: %s", err)
|
||||
}
|
||||
|
||||
g.handleWebsocketProtocol(ctx, conn, s)
|
||||
} else if r.Method == MethodRDGIN {
|
||||
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
|
||||
if g.ServerConf.SendBuf < 1 && g.ServerConf.ReceiveBuf < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// conn == tls.Conn
|
||||
ptr := reflect.ValueOf(conn)
|
||||
val := reflect.Indirect(ptr)
|
||||
|
||||
if val.Kind() != reflect.Struct {
|
||||
return errors.New("didn't get a struct from conn")
|
||||
}
|
||||
|
||||
// this gets net.Conn -> *net.TCPConn -> net.TCPConn
|
||||
ptrConn := val.FieldByName("conn")
|
||||
valConn := reflect.Indirect(ptrConn)
|
||||
if !valConn.IsValid() {
|
||||
return errors.New("cannot find conn field")
|
||||
}
|
||||
valConn = valConn.Elem().Elem()
|
||||
|
||||
// net.FD
|
||||
ptrNetFd := valConn.FieldByName("fd")
|
||||
valNetFd := reflect.Indirect(ptrNetFd)
|
||||
if !valNetFd.IsValid() {
|
||||
return errors.New("cannot find fd field")
|
||||
}
|
||||
|
||||
// pfd member
|
||||
ptrPfd := valNetFd.FieldByName("pfd")
|
||||
valPfd := reflect.Indirect(ptrPfd)
|
||||
if !valPfd.IsValid() {
|
||||
return errors.New("cannot find pfd field")
|
||||
}
|
||||
|
||||
// finally the exported Sysfd
|
||||
ptrSysFd := valPfd.FieldByName("Sysfd")
|
||||
if !ptrSysFd.IsValid() {
|
||||
return errors.New("cannot find Sysfd field")
|
||||
}
|
||||
fd := int(ptrSysFd.Int())
|
||||
|
||||
if g.ServerConf.ReceiveBuf > 0 {
|
||||
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ServerConf.ReceiveBuf)
|
||||
if err != nil {
|
||||
return wrapSyscallError("setsockopt", err)
|
||||
}
|
||||
}
|
||||
|
||||
if g.ServerConf.SendBuf > 0 {
|
||||
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.ServerConf.SendBuf)
|
||||
if err != nil {
|
||||
return wrapSyscallError("setsockopt", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
|
||||
websocketConnections.Inc()
|
||||
defer websocketConnections.Dec()
|
||||
|
||||
inout, _ := transport.NewWS(c)
|
||||
s.TransportOut = inout
|
||||
s.TransportIn = inout
|
||||
handler := NewServer(s, g.ServerConf)
|
||||
handler.Process(ctx)
|
||||
}
|
||||
|
||||
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
||||
// and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different
|
||||
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
||||
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s *SessionInfo) {
|
||||
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
|
||||
|
||||
if r.Method == MethodRDGOUT {
|
||||
out, err := transport.NewLegacy(w)
|
||||
if err != nil {
|
||||
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
|
||||
|
||||
s.TransportOut = out
|
||||
out.SendAccept(true)
|
||||
|
||||
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
||||
} else if r.Method == MethodRDGIN {
|
||||
legacyConnections.Inc()
|
||||
defer legacyConnections.Dec()
|
||||
|
||||
in, err := transport.NewLegacy(w)
|
||||
if err != nil {
|
||||
log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
|
||||
return
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
if s.TransportIn == nil {
|
||||
s.TransportIn = in
|
||||
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
||||
|
||||
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
|
||||
in.SendAccept(false)
|
||||
|
||||
// read some initial data
|
||||
in.Drain()
|
||||
|
||||
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
|
||||
handler := NewServer(s, g.ServerConf)
|
||||
handler.Process(r.Context())
|
||||
}
|
||||
}
|
||||
}
|
||||
211
cmd/rdpgw/protocol/protocol_test.go
Normal file
211
cmd/rdpgw/protocol/protocol_test.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
HeaderLen = 8
|
||||
HandshakeRequestLen = HeaderLen + 6
|
||||
HandshakeResponseLen = HeaderLen + 10
|
||||
TunnelCreateRequestLen = HeaderLen + 8 // + dynamic
|
||||
TunnelCreateResponseLen = HeaderLen + 18
|
||||
TunnelAuthLen = HeaderLen + 2 // + dynamic
|
||||
TunnelAuthResponseLen = HeaderLen + 16
|
||||
ChannelCreateLen = HeaderLen + 8 // + dynamic
|
||||
ChannelResponseLen = HeaderLen + 12
|
||||
)
|
||||
|
||||
func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) {
|
||||
pt, size, pkt, err := readHeader(data)
|
||||
|
||||
if pt != expPt {
|
||||
return 0, 0, []byte{}, fmt.Errorf("readHeader failed, expected packet type %d got %d", expPt, pt)
|
||||
}
|
||||
|
||||
if size != expSize {
|
||||
return 0, 0, []byte{}, fmt.Errorf("readHeader failed, expected size %d, got %d", expSize, size)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return 0, 0, []byte{}, err
|
||||
}
|
||||
|
||||
return pt, size, pkt, nil
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
client := ClientConfig{
|
||||
PAAToken: "abab",
|
||||
}
|
||||
s := &SessionInfo{}
|
||||
hc := &ServerConf{
|
||||
TokenAuth: true,
|
||||
}
|
||||
h := NewServer(s, hc)
|
||||
|
||||
data := client.handshakeRequest()
|
||||
|
||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_REQUEST, HandshakeRequestLen)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("verifyHeader failed: %s", err)
|
||||
}
|
||||
|
||||
log.Printf("pkt: %x", pkt)
|
||||
|
||||
major, minor, version, extAuth := h.handshakeRequest(pkt)
|
||||
if major != MajorVersion || minor != MinorVersion || version != Version {
|
||||
t.Fatalf("handshakeRequest failed got version %d.%d protocol %d, expected %d.%d protocol %d",
|
||||
major, minor, version, MajorVersion, MinorVersion, Version)
|
||||
}
|
||||
|
||||
if !((extAuth & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
|
||||
t.Fatalf("handshakeRequest failed got ext auth %d, expected %d", extAuth, extAuth|HTTP_EXTENDED_AUTH_PAA)
|
||||
}
|
||||
|
||||
data = h.handshakeResponse(0x0, 0x0)
|
||||
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_RESPONSE, HandshakeResponseLen)
|
||||
if err != nil {
|
||||
t.Fatalf("verifyHeader failed: %s", err)
|
||||
}
|
||||
log.Printf("pkt: %x", pkt)
|
||||
|
||||
caps, err := client.handshakeResponse(pkt)
|
||||
if !((caps & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
|
||||
t.Fatalf("handshakeResponse failed got caps %d, expected %d", caps, caps|HTTP_EXTENDED_AUTH_PAA)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTunnelCreation(t *testing.T) {
|
||||
client := ClientConfig{
|
||||
PAAToken: "abab",
|
||||
}
|
||||
s := &SessionInfo{}
|
||||
hc := &ServerConf{
|
||||
TokenAuth: true,
|
||||
}
|
||||
h := NewServer(s, hc)
|
||||
|
||||
data := client.tunnelRequest()
|
||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
|
||||
uint32(TunnelCreateRequestLen+2+len(client.PAAToken)*2))
|
||||
if err != nil {
|
||||
t.Fatalf("verifyHeader failed: %s", err)
|
||||
}
|
||||
|
||||
caps, token := h.tunnelRequest(pkt)
|
||||
if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
|
||||
t.Fatalf("tunnelRequest failed got caps %d, expected %d", caps, caps|HTTP_CAPABILITY_IDLE_TIMEOUT)
|
||||
}
|
||||
if token != client.PAAToken {
|
||||
t.Fatalf("tunnelRequest failed got token %s, expected %s", token, client.PAAToken)
|
||||
}
|
||||
|
||||
data = h.tunnelResponse()
|
||||
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_RESPONSE, TunnelCreateResponseLen)
|
||||
if err != nil {
|
||||
t.Fatalf("verifyHeader failed: %s", err)
|
||||
}
|
||||
|
||||
tid, caps, err := client.tunnelResponse(pkt)
|
||||
if err != nil {
|
||||
t.Fatalf("Error %s", err)
|
||||
}
|
||||
if tid != tunnelId {
|
||||
t.Fatalf("tunnelResponse failed tunnel id %d, expected %d", tid, tunnelId)
|
||||
}
|
||||
if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
|
||||
t.Fatalf("tunnelResponse failed got caps %d, expected %d", caps, caps|HTTP_CAPABILITY_IDLE_TIMEOUT)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTunnelAuth(t *testing.T) {
|
||||
name := "test_name"
|
||||
client := ClientConfig{
|
||||
Name: name,
|
||||
}
|
||||
s := &SessionInfo{}
|
||||
hc := &ServerConf{
|
||||
TokenAuth: true,
|
||||
IdleTimeout: 10,
|
||||
RedirectFlags: RedirectFlags{
|
||||
Clipboard: true,
|
||||
},
|
||||
}
|
||||
h := NewServer(s, hc)
|
||||
|
||||
data := client.tunnelAuthRequest()
|
||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2))
|
||||
if err != nil {
|
||||
t.Fatalf("verifyHeader failed: %s", err)
|
||||
}
|
||||
|
||||
n := h.tunnelAuthRequest(pkt)
|
||||
if n != name {
|
||||
t.Fatalf("tunnelAuthRequest failed got name %s, expected %s", n, name)
|
||||
}
|
||||
|
||||
data = h.tunnelAuthResponse()
|
||||
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH_RESPONSE, TunnelAuthResponseLen)
|
||||
if err != nil {
|
||||
t.Fatalf("verifyHeader failed: %s", err)
|
||||
}
|
||||
flags, timeout, err := client.tunnelAuthResponse(pkt)
|
||||
if err != nil {
|
||||
t.Fatalf("tunnel auth error %s", err)
|
||||
}
|
||||
if (flags & HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD) == HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD {
|
||||
t.Fatalf("tunnelAuthResponse failed got flags %d, expected %d",
|
||||
flags, flags|HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD)
|
||||
}
|
||||
if int(timeout) != hc.IdleTimeout {
|
||||
t.Fatalf("tunnelAuthResponse failed got timeout %d, expected %d",
|
||||
timeout, hc.IdleTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelCreation(t *testing.T) {
|
||||
server := "test_server"
|
||||
client := ClientConfig{
|
||||
Server: server,
|
||||
Port: 3389,
|
||||
}
|
||||
s := &SessionInfo{}
|
||||
hc := &ServerConf{
|
||||
TokenAuth: true,
|
||||
IdleTimeout: 10,
|
||||
RedirectFlags: RedirectFlags{
|
||||
Clipboard: true,
|
||||
},
|
||||
}
|
||||
h := NewServer(s, hc)
|
||||
|
||||
data := client.channelRequest()
|
||||
_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))
|
||||
if err != nil {
|
||||
t.Fatalf("verifyHeader failed: %s", err)
|
||||
}
|
||||
hServer, hPort := h.channelRequest(pkt)
|
||||
if hServer != server {
|
||||
t.Fatalf("channelRequest failed got server %s, expected %s", hServer, server)
|
||||
}
|
||||
if int(hPort) != client.Port {
|
||||
t.Fatalf("channelRequest failed got port %d, expected %d", hPort, client.Port)
|
||||
}
|
||||
|
||||
data = h.channelResponse()
|
||||
_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_CHANNEL_RESPONSE, uint32(ChannelResponseLen))
|
||||
if err != nil {
|
||||
t.Fatalf("verifyHeader failed: %s", err)
|
||||
}
|
||||
channelId, err := client.channelResponse(pkt)
|
||||
if err != nil {
|
||||
t.Fatalf("channelResponse failed: %s", err)
|
||||
}
|
||||
if channelId < 1 {
|
||||
t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId)
|
||||
}
|
||||
}
|
||||
340
cmd/rdpgw/protocol/server.go
Normal file
340
cmd/rdpgw/protocol/server.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type VerifyTunnelCreate func(context.Context, string) (bool, error)
|
||||
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
|
||||
type VerifyServerFunc func(context.Context, string) (bool, error)
|
||||
|
||||
type Server struct {
|
||||
Session *SessionInfo
|
||||
VerifyTunnelCreate VerifyTunnelCreate
|
||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||
VerifyServerFunc VerifyServerFunc
|
||||
RedirectFlags int
|
||||
IdleTimeout int
|
||||
SmartCardAuth bool
|
||||
TokenAuth bool
|
||||
ClientName string
|
||||
Remote net.Conn
|
||||
State int
|
||||
}
|
||||
|
||||
type ServerConf struct {
|
||||
VerifyTunnelCreate VerifyTunnelCreate
|
||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||
VerifyServerFunc VerifyServerFunc
|
||||
RedirectFlags RedirectFlags
|
||||
IdleTimeout int
|
||||
SmartCardAuth bool
|
||||
TokenAuth bool
|
||||
ReceiveBuf int
|
||||
SendBuf int
|
||||
}
|
||||
|
||||
func NewServer(s *SessionInfo, conf *ServerConf) *Server {
|
||||
h := &Server{
|
||||
State: SERVER_STATE_INITIAL,
|
||||
Session: s,
|
||||
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
||||
IdleTimeout: conf.IdleTimeout,
|
||||
SmartCardAuth: conf.SmartCardAuth,
|
||||
TokenAuth: conf.TokenAuth,
|
||||
VerifyTunnelCreate: conf.VerifyTunnelCreate,
|
||||
VerifyServerFunc: conf.VerifyServerFunc,
|
||||
VerifyTunnelAuthFunc: conf.VerifyTunnelAuthFunc,
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
const tunnelId = 10
|
||||
|
||||
func (s *Server) Process(ctx context.Context) error {
|
||||
for {
|
||||
pt, sz, pkt, err := readMessage(s.Session.TransportIn)
|
||||
if err != nil {
|
||||
log.Printf("Cannot read message from stream %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
switch pt {
|
||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
|
||||
if s.State != SERVER_STATE_INITIAL {
|
||||
log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIAL)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
major, minor, _, _ := s.handshakeRequest(pkt) // todo check if auth matches what the handler can do
|
||||
msg := s.handshakeResponse(major, minor)
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
s.State = SERVER_STATE_HANDSHAKE
|
||||
case PKT_TYPE_TUNNEL_CREATE:
|
||||
log.Printf("Tunnel create")
|
||||
if s.State != SERVER_STATE_HANDSHAKE {
|
||||
log.Printf("Tunnel create attempted while in wrong state %d != %d",
|
||||
s.State, SERVER_STATE_HANDSHAKE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
_, cookie := s.tunnelRequest(pkt)
|
||||
if s.VerifyTunnelCreate != nil {
|
||||
if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok {
|
||||
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
|
||||
return errors.New("invalid PAA cookie")
|
||||
}
|
||||
}
|
||||
msg := s.tunnelResponse()
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
s.State = SERVER_STATE_TUNNEL_CREATE
|
||||
case PKT_TYPE_TUNNEL_AUTH:
|
||||
log.Printf("Tunnel auth")
|
||||
if s.State != SERVER_STATE_TUNNEL_CREATE {
|
||||
log.Printf("Tunnel auth attempted while in wrong state %d != %d",
|
||||
s.State, SERVER_STATE_TUNNEL_CREATE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
client := s.tunnelAuthRequest(pkt)
|
||||
if s.VerifyTunnelAuthFunc != nil {
|
||||
if ok, _ := s.VerifyTunnelAuthFunc(ctx, client); !ok {
|
||||
log.Printf("Invalid client name: %s", client)
|
||||
return errors.New("invalid client name")
|
||||
}
|
||||
}
|
||||
msg := s.tunnelAuthResponse()
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
s.State = SERVER_STATE_TUNNEL_AUTHORIZE
|
||||
case PKT_TYPE_CHANNEL_CREATE:
|
||||
log.Printf("Channel create")
|
||||
if s.State != SERVER_STATE_TUNNEL_AUTHORIZE {
|
||||
log.Printf("Channel create attempted while in wrong state %d != %d",
|
||||
s.State, SERVER_STATE_TUNNEL_AUTHORIZE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
server, port := s.channelRequest(pkt)
|
||||
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||
if s.VerifyServerFunc != nil {
|
||||
if ok, _ := s.VerifyServerFunc(ctx, host); !ok {
|
||||
log.Printf("Not allowed to connect to %s by policy handler", host)
|
||||
return errors.New("denied by security policy")
|
||||
}
|
||||
}
|
||||
log.Printf("Establishing connection to RDP server: %s", host)
|
||||
s.Remote, err = net.DialTimeout("tcp", host, time.Second*15)
|
||||
if err != nil {
|
||||
log.Printf("Error connecting to %s, %s", host, err)
|
||||
return err
|
||||
}
|
||||
log.Printf("Connection established")
|
||||
msg := s.channelResponse()
|
||||
s.Session.TransportOut.WritePacket(msg)
|
||||
|
||||
// Make sure to start the flow from the RDP server first otherwise connections
|
||||
// might hang eventually
|
||||
go forward(s.Remote, s.Session.TransportOut)
|
||||
s.State = SERVER_STATE_CHANNEL_CREATE
|
||||
case PKT_TYPE_DATA:
|
||||
if s.State < SERVER_STATE_CHANNEL_CREATE {
|
||||
log.Printf("Data received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
s.State = SERVER_STATE_OPENED
|
||||
receive(pkt, s.Remote)
|
||||
case PKT_TYPE_KEEPALIVE:
|
||||
// keepalives can be received while the channel is not open yet
|
||||
if s.State < SERVER_STATE_CHANNEL_CREATE {
|
||||
log.Printf("Keepalive received while in wrong state %d != %d", s.State, SERVER_STATE_CHANNEL_CREATE)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
|
||||
// avoid concurrency issues
|
||||
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||
case PKT_TYPE_CLOSE_CHANNEL:
|
||||
log.Printf("Close channel")
|
||||
if s.State != SERVER_STATE_OPENED {
|
||||
log.Printf("Channel closed while in wrong state %d != %d", s.State, SERVER_STATE_OPENED)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
s.Session.TransportIn.Close()
|
||||
s.Session.TransportOut.Close()
|
||||
s.State = SERVER_STATE_CLOSED
|
||||
default:
|
||||
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Creates a packet the is a response to a handshakeRequest request
|
||||
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
|
||||
// but could be in Windows. However the NTLM protocol is insecure
|
||||
func (s *Server) handshakeResponse(major byte, minor byte) []byte {
|
||||
var caps uint16
|
||||
if s.SmartCardAuth {
|
||||
caps = caps | HTTP_EXTENDED_AUTH_SC
|
||||
}
|
||||
if s.TokenAuth {
|
||||
caps = caps | HTTP_EXTENDED_AUTH_PAA
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
binary.Write(buf, binary.LittleEndian, uint32(0)) // error_code
|
||||
buf.Write([]byte{major, minor})
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
|
||||
binary.Write(buf, binary.LittleEndian, uint16(caps)) // extended auth
|
||||
|
||||
return createPacket(PKT_TYPE_HANDSHAKE_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) handshakeRequest(data []byte) (major byte, minor byte, version uint16, extAuth uint16) {
|
||||
r := bytes.NewReader(data)
|
||||
binary.Read(r, binary.LittleEndian, &major)
|
||||
binary.Read(r, binary.LittleEndian, &minor)
|
||||
binary.Read(r, binary.LittleEndian, &version)
|
||||
binary.Read(r, binary.LittleEndian, &extAuth)
|
||||
|
||||
log.Printf("major: %d, minor: %d, version: %d, ext auth: %d", major, minor, version, extAuth)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) tunnelRequest(data []byte) (caps uint32, cookie string) {
|
||||
var fields uint16
|
||||
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
binary.Read(r, binary.LittleEndian, &caps)
|
||||
binary.Read(r, binary.LittleEndian, &fields)
|
||||
r.Seek(2, io.SeekCurrent)
|
||||
|
||||
if fields == HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE {
|
||||
var size uint16
|
||||
binary.Read(r, binary.LittleEndian, &size)
|
||||
cookieB := make([]byte, size)
|
||||
r.Read(cookieB)
|
||||
cookie, _ = DecodeUTF16(cookieB)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) tunnelResponse() []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // server version
|
||||
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
|
||||
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID|HTTP_TUNNEL_RESPONSE_FIELD_CAPS)) // fields present
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||
|
||||
// tunnel id (when is it used?)
|
||||
binary.Write(buf, binary.LittleEndian, uint32(tunnelId))
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(HTTP_CAPABILITY_IDLE_TIMEOUT))
|
||||
|
||||
return createPacket(PKT_TYPE_TUNNEL_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) tunnelAuthRequest(data []byte) string {
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
var size uint16
|
||||
binary.Read(buf, binary.LittleEndian, &size)
|
||||
clData := make([]byte, size)
|
||||
binary.Read(buf, binary.LittleEndian, &clData)
|
||||
clientName, _ := DecodeUTF16(clData)
|
||||
|
||||
return clientName
|
||||
}
|
||||
|
||||
func (s *Server) tunnelAuthResponse() []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
|
||||
binary.Write(buf, binary.LittleEndian, uint16(HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS|HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT)) // fields present
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||
|
||||
// idle timeout
|
||||
if s.IdleTimeout < 0 {
|
||||
s.IdleTimeout = 0
|
||||
}
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(s.RedirectFlags)) // redir flags
|
||||
binary.Write(buf, binary.LittleEndian, uint32(s.IdleTimeout)) // timeout in minutes
|
||||
|
||||
return createPacket(PKT_TYPE_TUNNEL_AUTH_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *Server) channelRequest(data []byte) (server string, port uint16) {
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
var resourcesSize byte
|
||||
var alternative byte
|
||||
var protocol uint16
|
||||
var nameSize uint16
|
||||
|
||||
binary.Read(buf, binary.LittleEndian, &resourcesSize)
|
||||
binary.Read(buf, binary.LittleEndian, &alternative)
|
||||
binary.Read(buf, binary.LittleEndian, &port)
|
||||
binary.Read(buf, binary.LittleEndian, &protocol)
|
||||
binary.Read(buf, binary.LittleEndian, &nameSize)
|
||||
|
||||
nameData := make([]byte, nameSize)
|
||||
binary.Read(buf, binary.LittleEndian, &nameData)
|
||||
|
||||
server, _ = DecodeUTF16(nameData)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) channelResponse() []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(0)) // error code
|
||||
binary.Write(buf, binary.LittleEndian, uint16(HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID)) // fields present
|
||||
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
|
||||
|
||||
// channel id is required for Windows clients
|
||||
binary.Write(buf, binary.LittleEndian, uint32(1)) // channel id
|
||||
|
||||
// optional fields
|
||||
// channel id uint32 (4)
|
||||
// udp port uint16 (2)
|
||||
// udp auth cookie 1 byte for side channel
|
||||
// length uint16
|
||||
|
||||
return createPacket(PKT_TYPE_CHANNEL_RESPONSE, buf.Bytes())
|
||||
}
|
||||
|
||||
func makeRedirectFlags(flags RedirectFlags) int {
|
||||
var redir = 0
|
||||
|
||||
if flags.DisableAll {
|
||||
return HTTP_TUNNEL_REDIR_DISABLE_ALL
|
||||
}
|
||||
if flags.EnableAll {
|
||||
return HTTP_TUNNEL_REDIR_ENABLE_ALL
|
||||
}
|
||||
|
||||
if !flags.Port {
|
||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PORT
|
||||
}
|
||||
if !flags.Clipboard {
|
||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD
|
||||
}
|
||||
if !flags.Drive {
|
||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_DRIVE
|
||||
}
|
||||
if !flags.Pnp {
|
||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PNP
|
||||
}
|
||||
if !flags.Printer {
|
||||
redir = redir | HTTP_TUNNEL_REDIR_DISABLE_PRINTER
|
||||
}
|
||||
return redir
|
||||
}
|
||||
78
cmd/rdpgw/protocol/types.go
Normal file
78
cmd/rdpgw/protocol/types.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package protocol
|
||||
|
||||
const (
|
||||
PKT_TYPE_HANDSHAKE_REQUEST = 0x1
|
||||
PKT_TYPE_HANDSHAKE_RESPONSE = 0x2
|
||||
PKT_TYPE_EXTENDED_AUTH_MSG = 0x3
|
||||
PKT_TYPE_TUNNEL_CREATE = 0x4
|
||||
PKT_TYPE_TUNNEL_RESPONSE = 0x5
|
||||
PKT_TYPE_TUNNEL_AUTH = 0x6
|
||||
PKT_TYPE_TUNNEL_AUTH_RESPONSE = 0x7
|
||||
PKT_TYPE_CHANNEL_CREATE = 0x8
|
||||
PKT_TYPE_CHANNEL_RESPONSE = 0x9
|
||||
PKT_TYPE_DATA = 0xA
|
||||
PKT_TYPE_SERVICE_MESSAGE = 0xB
|
||||
PKT_TYPE_REAUTH_MESSAGE = 0xC
|
||||
PKT_TYPE_KEEPALIVE = 0xD
|
||||
PKT_TYPE_CLOSE_CHANNEL = 0x10
|
||||
PKT_TYPE_CLOSE_CHANNEL_RESPONSE = 0x11
|
||||
)
|
||||
|
||||
const (
|
||||
HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID = 0x01
|
||||
HTTP_TUNNEL_RESPONSE_FIELD_CAPS = 0x02
|
||||
HTTP_TUNNEL_RESPONSE_FIELD_SOH_REQ = 0x04
|
||||
HTTP_TUNNEL_RESPONSE_FIELD_CONSENT_MSG = 0x10
|
||||
)
|
||||
|
||||
const (
|
||||
HTTP_EXTENDED_AUTH_NONE = 0x0
|
||||
HTTP_EXTENDED_AUTH_SC = 0x1 /* Smart card authentication. */
|
||||
HTTP_EXTENDED_AUTH_PAA = 0x02 /* Pluggable authentication. */
|
||||
HTTP_EXTENDED_AUTH_SSPI_NTLM = 0x04 /* NTLM extended authentication. */
|
||||
)
|
||||
|
||||
const (
|
||||
HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS = 0x01
|
||||
HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT = 0x02
|
||||
HTTP_TUNNEL_AUTH_RESPONSE_FIELD_SOH_RESPONSE = 0x04
|
||||
)
|
||||
|
||||
const (
|
||||
HTTP_TUNNEL_REDIR_ENABLE_ALL = 0x80000000
|
||||
HTTP_TUNNEL_REDIR_DISABLE_ALL = 0x40000000
|
||||
HTTP_TUNNEL_REDIR_DISABLE_DRIVE = 0x01
|
||||
HTTP_TUNNEL_REDIR_DISABLE_PRINTER = 0x02
|
||||
HTTP_TUNNEL_REDIR_DISABLE_PORT = 0x04
|
||||
HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD = 0x08
|
||||
HTTP_TUNNEL_REDIR_DISABLE_PNP = 0x10
|
||||
)
|
||||
|
||||
const (
|
||||
HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID = 0x01
|
||||
HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE = 0x02
|
||||
HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT = 0x04
|
||||
)
|
||||
|
||||
const (
|
||||
HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
|
||||
)
|
||||
|
||||
const (
|
||||
SERVER_STATE_INITIAL = 0x0
|
||||
SERVER_STATE_HANDSHAKE = 0x1
|
||||
SERVER_STATE_TUNNEL_CREATE = 0x2
|
||||
SERVER_STATE_TUNNEL_AUTHORIZE = 0x3
|
||||
SERVER_STATE_CHANNEL_CREATE = 0x4
|
||||
SERVER_STATE_OPENED = 0x5
|
||||
SERVER_STATE_CLOSED = 0x6
|
||||
)
|
||||
|
||||
const (
|
||||
HTTP_CAPABILITY_TYPE_QUAR_SOH = 0x1
|
||||
HTTP_CAPABILITY_IDLE_TIMEOUT = 0x2
|
||||
HTTP_CAPABILITY_MESSAGING_CONSENT_SIGN = 0x4
|
||||
HTTP_CAPABILITY_MESSAGING_SERVICE_MSG = 0x8
|
||||
HTTP_CAPABILITY_REAUTH = 0x10
|
||||
HTTP_CAPABILITY_UDP_TRANSPORT = 0x20
|
||||
)
|
||||
42
cmd/rdpgw/protocol/utf16.go
Normal file
42
cmd/rdpgw/protocol/utf16.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"unicode/utf16"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func DecodeUTF16(b []byte) (string, error) {
|
||||
if len(b)%2 != 0 {
|
||||
return "", fmt.Errorf("must have even length byte slice")
|
||||
}
|
||||
|
||||
u16s := make([]uint16, 1)
|
||||
ret := &bytes.Buffer{}
|
||||
b8buf := make([]byte, 4)
|
||||
|
||||
lb := len(b)
|
||||
for i := 0; i < lb; i += 2 {
|
||||
u16s[0] = uint16(b[i]) + (uint16(b[i+1]) << 8)
|
||||
r := utf16.Decode(u16s)
|
||||
n := utf8.EncodeRune(b8buf, r[0])
|
||||
ret.Write(b8buf[:n])
|
||||
}
|
||||
|
||||
bret := ret.Bytes()
|
||||
if len(bret) > 0 && bret[len(bret)-1] == '\x00' {
|
||||
bret = bret[:len(bret)-1]
|
||||
}
|
||||
return string(bret), nil
|
||||
}
|
||||
|
||||
func EncodeUTF16(s string) []byte {
|
||||
ret := new(bytes.Buffer)
|
||||
enc := utf16.Encode([]rune(s))
|
||||
for c := range enc {
|
||||
binary.Write(ret, binary.LittleEndian, enc[c])
|
||||
}
|
||||
return ret.Bytes()
|
||||
}
|
||||
240
cmd/rdpgw/security/jwt.go
Normal file
240
cmd/rdpgw/security/jwt.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/square/go-jose/v3"
|
||||
"github.com/square/go-jose/v3/jwt"
|
||||
"golang.org/x/oauth2"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
SigningKey []byte
|
||||
EncryptionKey []byte
|
||||
UserSigningKey []byte
|
||||
UserEncryptionKey []byte
|
||||
OIDCProvider *oidc.Provider
|
||||
Oauth2Config oauth2.Config
|
||||
)
|
||||
|
||||
var ExpiryTime time.Duration = 5
|
||||
var VerifyClientIP bool = true
|
||||
|
||||
type customClaims struct {
|
||||
RemoteServer string `json:"remoteServer"`
|
||||
ClientIP string `json:"clientIp"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
}
|
||||
|
||||
func VerifyPAAToken(ctx context.Context, tokenString string) (bool, error) {
|
||||
token, err := jwt.ParseSigned(tokenString)
|
||||
|
||||
// check if the signing algo matches what we expect
|
||||
for _, header := range token.Headers {
|
||||
if header.Algorithm != string(jose.HS256) {
|
||||
return false, fmt.Errorf("unexpected signing method: %v", header.Algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
standard := jwt.Claims{}
|
||||
custom := customClaims{}
|
||||
|
||||
// Claims automagically checks the signature...
|
||||
err = token.Claims(SigningKey, &standard, &custom)
|
||||
if err != nil {
|
||||
log.Printf("token signature validation failed due to %s", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
// ...but doesn't check the expiry claim :/
|
||||
err = standard.Validate(jwt.Expected{
|
||||
Issuer: "rdpgw",
|
||||
Time: time.Now(),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("token validation failed due to %s", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
// validate the access token
|
||||
tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken})
|
||||
_, err = OIDCProvider.UserInfo(ctx, tokenSource)
|
||||
if err != nil {
|
||||
log.Printf("Cannot get user info for access token: %s", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
s := getSessionInfo(ctx)
|
||||
|
||||
s.RemoteServer = custom.RemoteServer
|
||||
s.ClientIp = custom.ClientIP
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func VerifyServerFunc(ctx context.Context, host string) (bool, error) {
|
||||
s := getSessionInfo(ctx)
|
||||
if s == nil {
|
||||
return false, errors.New("no valid session info found in context")
|
||||
}
|
||||
|
||||
if s.RemoteServer != host {
|
||||
log.Printf("Client specified host %s does not match token host %s", host, s.RemoteServer)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if VerifyClientIP && s.ClientIp != common.GetClientIp(ctx) {
|
||||
log.Printf("Current client ip address %s does not match token client ip %s",
|
||||
common.GetClientIp(ctx), s.ClientIp)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func GeneratePAAToken(ctx context.Context, username string, server string) (string, error) {
|
||||
if len(SigningKey) < 32 {
|
||||
return "", errors.New("token signing key not long enough or not specified")
|
||||
}
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: SigningKey}, nil)
|
||||
if err != nil {
|
||||
log.Printf("Cannot obtain signer %s", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
standard := jwt.Claims{
|
||||
Issuer: "rdpgw",
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
|
||||
Subject: username,
|
||||
}
|
||||
|
||||
private := customClaims{
|
||||
RemoteServer: server,
|
||||
ClientIP: common.GetClientIp(ctx),
|
||||
AccessToken: common.GetAccessToken(ctx),
|
||||
}
|
||||
|
||||
if token, err := jwt.Signed(sig).Claims(standard).Claims(private).CompactSerialize(); err != nil {
|
||||
log.Printf("Cannot sign PAA token %s", err)
|
||||
return "", err
|
||||
} else {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateUserToken(ctx context.Context, userName string) (string, error) {
|
||||
if len(UserEncryptionKey) < 32 {
|
||||
return "", errors.New("user token encryption key not long enough or not specified")
|
||||
}
|
||||
|
||||
claims := jwt.Claims{
|
||||
Subject: userName,
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
|
||||
Issuer: "rdpgw",
|
||||
}
|
||||
|
||||
enc, err := jose.NewEncrypter(
|
||||
jose.A128CBC_HS256,
|
||||
jose.Recipient{Algorithm: jose.DIRECT, Key: UserEncryptionKey},
|
||||
(&jose.EncrypterOptions{Compression: jose.DEFLATE}).WithContentType("JWT"),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Cannot encrypt user token due to %s", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// this makes the token bigger and we deal with a limited space of 511 characters
|
||||
// sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: SigningKey}, nil)
|
||||
// token, err := jwt.SignedAndEncrypted(sig, enc).Claims(claims).CompactSerialize()
|
||||
token, err := jwt.Encrypted(enc).Claims(claims).CompactSerialize()
|
||||
return token, err
|
||||
}
|
||||
|
||||
func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
|
||||
standard := jwt.Claims{}
|
||||
if len(UserEncryptionKey) > 0 && len(UserSigningKey) > 0 {
|
||||
enc, err := jwt.ParseSignedAndEncrypted(token)
|
||||
if err != nil {
|
||||
log.Printf("Cannot get token %s", err)
|
||||
return standard, errors.New("cannot get token")
|
||||
}
|
||||
token, err := enc.Decrypt(UserEncryptionKey)
|
||||
if err != nil {
|
||||
log.Printf("Cannot decrypt token %s", err)
|
||||
return standard, errors.New("cannot decrypt token")
|
||||
}
|
||||
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
|
||||
log.Printf("signature validation failure: %s", err)
|
||||
return standard, errors.New("signature validation failure")
|
||||
}
|
||||
if err = token.Claims(UserSigningKey, &standard); err != nil {
|
||||
log.Printf("cannot verify signature %s", err)
|
||||
return standard, errors.New("cannot verify signature")
|
||||
}
|
||||
} else if len(UserSigningKey) == 0 {
|
||||
token, err := jwt.ParseEncrypted(token)
|
||||
if err != nil {
|
||||
log.Printf("Cannot get token %s", err)
|
||||
return standard, errors.New("cannot get token")
|
||||
}
|
||||
err = token.Claims(UserEncryptionKey, &standard)
|
||||
if err != nil {
|
||||
log.Printf("Cannot decrypt token %s", err)
|
||||
return standard, errors.New("cannot decrypt token")
|
||||
}
|
||||
} else {
|
||||
token, err := jwt.ParseSigned(token)
|
||||
if err != nil {
|
||||
log.Printf("Cannot get token %s", err)
|
||||
return standard, errors.New("cannot get token")
|
||||
}
|
||||
if _, err := verifyAlg(token.Headers, string(jose.HS256)); err != nil {
|
||||
log.Printf("signature validation failure: %s", err)
|
||||
return standard, errors.New("signature validation failure")
|
||||
}
|
||||
err = token.Claims(UserSigningKey, &standard)
|
||||
if err = token.Claims(UserSigningKey, &standard); err != nil {
|
||||
log.Printf("cannot verify signature %s", err)
|
||||
return standard, errors.New("cannot verify signature")
|
||||
}
|
||||
}
|
||||
|
||||
// go-jose doesnt verify the expiry
|
||||
err := standard.Validate(jwt.Expected{
|
||||
Issuer: "rdpgw",
|
||||
Time: time.Now(),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("token validation failed due to %s", err)
|
||||
return standard, fmt.Errorf("token validation failed due to %s", err)
|
||||
}
|
||||
|
||||
return standard, nil
|
||||
}
|
||||
|
||||
func getSessionInfo(ctx context.Context) *protocol.SessionInfo {
|
||||
s, ok := ctx.Value("SessionInfo").(*protocol.SessionInfo)
|
||||
if !ok {
|
||||
log.Printf("cannot get session info from context")
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func verifyAlg(headers []jose.Header, alg string) (bool, error) {
|
||||
for _, header := range headers {
|
||||
if header.Algorithm != alg {
|
||||
return false, fmt.Errorf("invalid signing method %s", header.Algorithm)
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
82
cmd/rdpgw/transport/legacy.go
Normal file
82
cmd/rdpgw/transport/legacy.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
crlf = "\r\n"
|
||||
HttpOK = "HTTP/1.1 200 OK\r\n"
|
||||
)
|
||||
|
||||
type LegacyPKT struct {
|
||||
Conn net.Conn
|
||||
ChunkedReader io.Reader
|
||||
Writer *bufio.Writer
|
||||
}
|
||||
|
||||
func NewLegacy(w http.ResponseWriter) (*LegacyPKT, error) {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if ok {
|
||||
conn, rw, err := hj.Hijack()
|
||||
l := &LegacyPKT{
|
||||
Conn: conn,
|
||||
ChunkedReader: httputil.NewChunkedReader(rw.Reader),
|
||||
Writer: rw.Writer,
|
||||
}
|
||||
return l, err
|
||||
}
|
||||
|
||||
return nil, errors.New("cannot hijack connection")
|
||||
}
|
||||
|
||||
func (t *LegacyPKT) ReadPacket() (n int, p []byte, err error){
|
||||
buf := make([]byte, 4096) // bufio.defaultBufSize
|
||||
n, err = t.ChunkedReader.Read(buf)
|
||||
p = make([]byte, n)
|
||||
copy(p, buf)
|
||||
|
||||
return n, p, err
|
||||
}
|
||||
|
||||
func (t *LegacyPKT) WritePacket(b []byte) (n int, err error) {
|
||||
return t.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (t *LegacyPKT) Close() error {
|
||||
return t.Conn.Close()
|
||||
}
|
||||
|
||||
// [MS-TSGU]: Terminal Services Gateway Server Protocol version 39.0
|
||||
// The server sends back the final status code 200 OK, and also a random entity body of limited size (100 bytes).
|
||||
// This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does
|
||||
// not specify an entity length in its response. It uses HTTP 1.0 semantics to send the entity body and closes the
|
||||
// connection after the last byte is sent.
|
||||
func (t *LegacyPKT) SendAccept(doSeed bool) {
|
||||
t.Writer.WriteString(HttpOK)
|
||||
t.Writer.WriteString("Date: " + time.Now().Format(time.RFC1123) + crlf)
|
||||
if !doSeed {
|
||||
t.Writer.WriteString("Content-Length: 0" + crlf)
|
||||
}
|
||||
t.Writer.WriteString(crlf)
|
||||
|
||||
if doSeed {
|
||||
seed := make([]byte, 10)
|
||||
rand.Read(seed)
|
||||
// docs say it's a seed but 2019 responds with ab cd * 5
|
||||
t.Writer.Write(seed)
|
||||
}
|
||||
t.Writer.Flush()
|
||||
}
|
||||
|
||||
func (t *LegacyPKT) Drain() {
|
||||
p := make([]byte, 32767)
|
||||
t.Conn.Read(p)
|
||||
}
|
||||
8
cmd/rdpgw/transport/transport.go
Normal file
8
cmd/rdpgw/transport/transport.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package transport
|
||||
|
||||
type Transport interface {
|
||||
ReadPacket() (n int, p []byte, err error)
|
||||
WritePacket(b []byte) (n int, err error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
42
cmd/rdpgw/transport/websocket.go
Normal file
42
cmd/rdpgw/transport/websocket.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type WSPKT struct {
|
||||
Conn *websocket.Conn
|
||||
}
|
||||
|
||||
func NewWS(c *websocket.Conn) (*WSPKT, error) {
|
||||
w := &WSPKT{Conn: c}
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func (t *WSPKT) ReadPacket() (n int, b []byte, err error) {
|
||||
mt, msg, err := t.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
return 0, []byte{0, 0}, err
|
||||
}
|
||||
|
||||
if mt == websocket.BinaryMessage {
|
||||
return len(msg), msg, nil
|
||||
}
|
||||
|
||||
return len(msg), msg, errors.New("not a binary packet")
|
||||
}
|
||||
|
||||
func (t *WSPKT) WritePacket(b []byte) (n int, err error) {
|
||||
err = t.Conn.WriteMessage(websocket.BinaryMessage, b)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (t *WSPKT) Close() error {
|
||||
return t.Conn.Close()
|
||||
}
|
||||
Reference in New Issue
Block a user