Websocket refactor

This commit is contained in:
Bolke de Bruin
2020-07-20 15:26:36 +02:00
parent 2f78a7fd8e
commit 33290f59e6
4 changed files with 380 additions and 93 deletions

89
rdg.go
View File

@@ -3,7 +3,6 @@ package main
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/bolkedebruin/rdpgw/protocol"
"github.com/bolkedebruin/rdpgw/transport"
@@ -157,80 +156,12 @@ func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
}
func handleWebsocketProtocol(c *websocket.Conn) {
var remote net.Conn
websocketConnections.Inc()
defer websocketConnections.Dec()
inout, _ := transport.NewWS(c)
handler := protocol.NewHandler(inout)
var host string
for {
pt, sz, pkt, err := handler.ReadMessage()
if err != nil {
log.Printf("Cannot read message from stream %s", err)
return
}
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
major, minor, _, auth := readHandshake(pkt)
msg := handshakeResponse(major, minor, auth)
log.Printf("Handshake response: %x", msg)
inout.WritePacket(msg)
case PKT_TYPE_TUNNEL_CREATE:
readCreateTunnelRequest(pkt)
/*data, found := tokens.Get(cookie)
if found == false {
log.Printf("Invalid PAA cookie: %s from %s", cookie, inout.Conn.RemoteAddr())
return
}*/
host = conf.Server.HostTemplate
/*
for k, v := range data.(map[string]interface{}) {
if val, ok := v.(string); ok == true {
host = strings.Replace(host, "{{ " + k + " }}", val, 1)
}
}*/
msg := createTunnelResponse()
log.Printf("Create tunnel response: %x", msg)
inout.WritePacket(msg)
case PKT_TYPE_TUNNEL_AUTH:
readTunnelAuthRequest(pkt)
msg := createTunnelAuthResponse()
log.Printf("Create tunnel auth response: %x", msg)
inout.WritePacket(msg)
case PKT_TYPE_CHANNEL_CREATE:
server, port := readChannelCreateRequest(pkt)
if conf.Server.EnableOverride == true {
log.Printf("Override allowed")
host = net.JoinHostPort(server, strconv.Itoa(int(port)))
}
log.Printf("Establishing connection to RDP server: %s", host)
remote, err = net.DialTimeout(
"tcp",
host,
time.Second * 30)
if err != nil {
log.Printf("Error connecting to %s", host)
return
}
log.Printf("Connection established")
msg := createChannelCreateResponse()
log.Printf("Create channel create response: %x", msg)
inout.WritePacket(msg)
go sendDataPacket(remote, inout)
case PKT_TYPE_DATA:
forwardDataPacket(remote, pkt)
case PKT_TYPE_KEEPALIVE:
// do not write to make sure we do not create concurrency issues
// inout.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL:
break
default:
log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz, pkt)
}
}
handler.Process()
}
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
@@ -537,24 +468,6 @@ func forwardDataPacket(conn net.Conn, data []byte) {
conn.Write(pkt)
}
func handleWebsocketData(rdp net.Conn, conn transport.Transport) {
defer rdp.Close()
b1 := new(bytes.Buffer)
buf := make([]byte, 4086)
for {
n, err := rdp.Read(buf)
binary.Write(b1, binary.LittleEndian, uint16(n))
if err != nil {
log.Printf("Error reading from conn %s", err)
break
}
b1.Write(buf[:n])
conn.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset()
}
}
func sendDataPacket(connIn net.Conn, connOut transport.Transport) {
defer connIn.Close()
b1 := new(bytes.Buffer)