mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-03-29 07:06:34 +00:00
Websocket refactor
This commit is contained in:
89
rdg.go
89
rdg.go
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bolkedebruin/rdpgw/protocol"
|
||||
"github.com/bolkedebruin/rdpgw/transport"
|
||||
@@ -157,80 +156,12 @@ func handleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func handleWebsocketProtocol(c *websocket.Conn) {
|
||||
var remote net.Conn
|
||||
|
||||
websocketConnections.Inc()
|
||||
defer websocketConnections.Dec()
|
||||
|
||||
inout, _ := transport.NewWS(c)
|
||||
handler := protocol.NewHandler(inout)
|
||||
|
||||
var host string
|
||||
for {
|
||||
pt, sz, pkt, err := handler.ReadMessage()
|
||||
if err != nil {
|
||||
log.Printf("Cannot read message from stream %s", err)
|
||||
return
|
||||
}
|
||||
switch pt {
|
||||
case PKT_TYPE_HANDSHAKE_REQUEST:
|
||||
major, minor, _, auth := readHandshake(pkt)
|
||||
msg := handshakeResponse(major, minor, auth)
|
||||
log.Printf("Handshake response: %x", msg)
|
||||
inout.WritePacket(msg)
|
||||
case PKT_TYPE_TUNNEL_CREATE:
|
||||
readCreateTunnelRequest(pkt)
|
||||
/*data, found := tokens.Get(cookie)
|
||||
if found == false {
|
||||
log.Printf("Invalid PAA cookie: %s from %s", cookie, inout.Conn.RemoteAddr())
|
||||
return
|
||||
}*/
|
||||
host = conf.Server.HostTemplate
|
||||
/*
|
||||
for k, v := range data.(map[string]interface{}) {
|
||||
if val, ok := v.(string); ok == true {
|
||||
host = strings.Replace(host, "{{ " + k + " }}", val, 1)
|
||||
}
|
||||
}*/
|
||||
msg := createTunnelResponse()
|
||||
log.Printf("Create tunnel response: %x", msg)
|
||||
inout.WritePacket(msg)
|
||||
case PKT_TYPE_TUNNEL_AUTH:
|
||||
readTunnelAuthRequest(pkt)
|
||||
msg := createTunnelAuthResponse()
|
||||
log.Printf("Create tunnel auth response: %x", msg)
|
||||
inout.WritePacket(msg)
|
||||
case PKT_TYPE_CHANNEL_CREATE:
|
||||
server, port := readChannelCreateRequest(pkt)
|
||||
if conf.Server.EnableOverride == true {
|
||||
log.Printf("Override allowed")
|
||||
host = net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||
}
|
||||
log.Printf("Establishing connection to RDP server: %s", host)
|
||||
remote, err = net.DialTimeout(
|
||||
"tcp",
|
||||
host,
|
||||
time.Second * 30)
|
||||
if err != nil {
|
||||
log.Printf("Error connecting to %s", host)
|
||||
return
|
||||
}
|
||||
log.Printf("Connection established")
|
||||
msg := createChannelCreateResponse()
|
||||
log.Printf("Create channel create response: %x", msg)
|
||||
inout.WritePacket(msg)
|
||||
go sendDataPacket(remote, inout)
|
||||
case PKT_TYPE_DATA:
|
||||
forwardDataPacket(remote, pkt)
|
||||
case PKT_TYPE_KEEPALIVE:
|
||||
// do not write to make sure we do not create concurrency issues
|
||||
// inout.WriteMessage(mt, createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||
case PKT_TYPE_CLOSE_CHANNEL:
|
||||
break
|
||||
default:
|
||||
log.Printf("Unknown packet type: %d (size: %d), %x", pt, sz, pkt)
|
||||
}
|
||||
}
|
||||
handler.Process()
|
||||
}
|
||||
|
||||
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
||||
@@ -537,24 +468,6 @@ func forwardDataPacket(conn net.Conn, data []byte) {
|
||||
conn.Write(pkt)
|
||||
}
|
||||
|
||||
func handleWebsocketData(rdp net.Conn, conn transport.Transport) {
|
||||
defer rdp.Close()
|
||||
b1 := new(bytes.Buffer)
|
||||
buf := make([]byte, 4086)
|
||||
|
||||
for {
|
||||
n, err := rdp.Read(buf)
|
||||
binary.Write(b1, binary.LittleEndian, uint16(n))
|
||||
if err != nil {
|
||||
log.Printf("Error reading from conn %s", err)
|
||||
break
|
||||
}
|
||||
b1.Write(buf[:n])
|
||||
conn.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||
b1.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func sendDataPacket(connIn net.Conn, connOut transport.Transport) {
|
||||
defer connIn.Close()
|
||||
b1 := new(bytes.Buffer)
|
||||
|
||||
Reference in New Issue
Block a user