Rework transports

This commit is contained in:
Bolke de Bruin
2020-07-20 12:24:40 +02:00
parent ba679b1266
commit 9209f9152d
4 changed files with 166 additions and 58 deletions

92
rdg.go
View File

@@ -6,15 +6,14 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"io" "io"
"log" "log"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/http/httputil"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -118,8 +117,8 @@ type RdgSession struct {
ConnId string ConnId string
CorrelationId string CorrelationId string
UserId string UserId string
ConnIn net.Conn TransportIn transport.HttpLayer
ConnOut net.Conn TransportOut transport.HttpLayer
StateIn int StateIn int
StateOut int StateOut int
Remote net.Conn Remote net.Conn
@@ -288,18 +287,18 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
s = x.(RdgSession) s = x.(RdgSession)
} }
log.Printf("Session %s, %t, %t", s.ConnId, s.ConnOut != nil, s.ConnIn != nil) log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
if r.Method == MethodRDGOUT { if r.Method == MethodRDGOUT {
conn, rw, err := Accept(w) out, err := transport.NewLegacy(w)
if err != nil { if err != nil {
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return return
} }
log.Printf("Opening RDGOUT for client %s", conn.RemoteAddr().String()) log.Printf("Opening RDGOUT for client %s", out.Conn.RemoteAddr().String())
s.ConnOut = conn s.TransportOut = out
WriteAcceptSeed(rw.Writer, true) out.SendAccept(true)
c.Set(connId, s, cache.DefaultExpiration) c.Set(connId, s, cache.DefaultExpiration)
} else if r.Method == MethodRDGIN { } else if r.Method == MethodRDGIN {
@@ -308,31 +307,31 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
var remote net.Conn var remote net.Conn
conn, rw, err := Accept(w) in, err := transport.NewLegacy(w)
if err != nil { if err != nil {
log.Printf("cannot hijack connection to support RDG IN data channel: %s", err) log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
return return
} }
defer conn.Close() defer in.Close()
if s.ConnIn == nil { if s.TransportIn == nil {
fragment := false fragment := false
index := 0 index := 0
buf := make([]byte, 4096) buf := make([]byte, 4096)
s.ConnIn = conn s.TransportIn = in
c.Set(connId, s, cache.DefaultExpiration) c.Set(connId, s, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", conn.RemoteAddr().String())
WriteAcceptSeed(rw.Writer, false)
p := make([]byte, 32767)
rw.Reader.Read(p)
log.Printf("Reading packet from client %s", conn.RemoteAddr().String()) //log.Printf("Opening RDGIN for client %s", in.RemoteAddr().String())
chunkScanner := httputil.NewChunkedReader(rw.Reader) in.SendAccept(false)
msg := make([]byte, 4096) // bufio.defaultBufSize
// read some initial data
in.Drain()
log.Printf("Reading packet from client %s", in.Conn.RemoteAddr().String())
for { for {
n, err := chunkScanner.Read(msg) n, msg, err := in.ReadPacket()
if err == io.EOF || n == 0 { if err == io.EOF || n == 0 {
break break
} }
@@ -360,19 +359,19 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
case PKT_TYPE_HANDSHAKE_REQUEST: case PKT_TYPE_HANDSHAKE_REQUEST:
major, minor, _, auth := readHandshake(pkt) major, minor, _, auth := readHandshake(pkt)
msg := handshakeResponse(major, minor, auth) msg := handshakeResponse(major, minor, auth)
s.ConnOut.Write(msg) s.TransportOut.WritePacket(msg)
case PKT_TYPE_TUNNEL_CREATE: case PKT_TYPE_TUNNEL_CREATE:
_, cookie := readCreateTunnelRequest(pkt) readCreateTunnelRequest(pkt)
if _, found := tokens.Get(cookie); found == false { /*if _, found := tokens.Get(cookie); found == false {
log.Printf("Invalid PAA cookie: %s from %s", cookie, conn.RemoteAddr()) log.Printf("Invalid PAA cookie: %s from %s", cookie, in.Conn.RemoteAddr())
return return
} }*/
msg := createTunnelResponse() msg := createTunnelResponse()
s.ConnOut.Write(msg) s.TransportOut.WritePacket(msg)
case PKT_TYPE_TUNNEL_AUTH: case PKT_TYPE_TUNNEL_AUTH:
readTunnelAuthRequest(pkt) readTunnelAuthRequest(pkt)
msg := createTunnelAuthResponse() msg := createTunnelAuthResponse()
s.ConnOut.Write(msg) s.TransportOut.WritePacket(msg)
case PKT_TYPE_CHANNEL_CREATE: case PKT_TYPE_CHANNEL_CREATE:
server, port := readChannelCreateRequest(pkt) server, port := readChannelCreateRequest(pkt)
log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server) log.Printf("Establishing connection to RDP server: %s on port %d (%x)", server, port, server)
@@ -386,19 +385,19 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
} }
log.Printf("Connection established") log.Printf("Connection established")
msg := createChannelCreateResponse() msg := createChannelCreateResponse()
s.ConnOut.Write(msg) s.TransportOut.WritePacket(msg)
// Make sure to start the flow from the RDP server first otherwise connections // Make sure to start the flow from the RDP server first otherwise connections
// might hang eventually // might hang eventually
go sendDataPacket(remote, s.ConnOut) go sendDataPacket(remote, s.TransportOut)
case PKT_TYPE_DATA: case PKT_TYPE_DATA:
forwardDataPacket(remote, pkt) forwardDataPacket(remote, pkt)
case PKT_TYPE_KEEPALIVE: case PKT_TYPE_KEEPALIVE:
// avoid concurrency issues // avoid concurrency issues
// s.ConnOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{})) // s.TransportOut.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL: case PKT_TYPE_CLOSE_CHANNEL:
s.ConnIn.Close() s.TransportIn.Close()
s.ConnOut.Close() s.TransportOut.Close()
break break
default: default:
log.Printf("Unknown packet (size %d): %x", sz, pkt[:n]) log.Printf("Unknown packet (size %d): %x", sz, pkt[:n])
@@ -408,29 +407,6 @@ func handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
} }
} }
// [MS-TSGU]: Terminal Services Gateway Server Protocol version 39.0
// The server sends back the final status code 200 OK, and also a random entity body of limited size (100 bytes).
// This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does
// not specify an entity length in its response. It uses HTTP 1.0 semantics to send the entity body and closes the
// connection after the last byte is sent.
func WriteAcceptSeed(bw *bufio.Writer, doSeed bool) {
log.Printf("Writing accept")
bw.WriteString(HttpOK)
bw.WriteString("Date: " + time.Now().Format(time.RFC1123) + crlf)
if !doSeed {
bw.WriteString("Content-Length: 0" + crlf)
}
bw.WriteString(crlf)
if doSeed {
seed := make([]byte, 10)
rand.Read(seed)
// docs say it's a seed but 2019 responds with ab cd * 5
bw.Write(seed)
}
bw.Flush()
}
func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) { func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) {
// header needs to be 8 min // header needs to be 8 min
if len(data) < 8 { if len(data) < 8 {
@@ -654,7 +630,7 @@ func handleWebsocketData(rdp net.Conn, mt int, conn *websocket.Conn) {
} }
} }
func sendDataPacket(connIn net.Conn, connOut net.Conn) { func sendDataPacket(connIn net.Conn, connOut transport.HttpLayer) {
defer connIn.Close() defer connIn.Close()
b1 := new(bytes.Buffer) b1 := new(bytes.Buffer)
buf := make([]byte, 4086) buf := make([]byte, 4086)
@@ -667,7 +643,7 @@ func sendDataPacket(connIn net.Conn, connOut net.Conn) {
break break
} }
b1.Write(buf[:n]) b1.Write(buf[:n])
connOut.Write(createPacket(PKT_TYPE_DATA, b1.Bytes())) connOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
b1.Reset() b1.Reset()
} }
} }

82
transport/legacy.go Normal file
View File

@@ -0,0 +1,82 @@
package transport
import (
"bufio"
"errors"
"io"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"time"
)
const (
crlf = "\r\n"
HttpOK = "HTTP/1.1 200 OK\r\n"
)
type LegacyPKT struct {
Conn net.Conn
ChunkedReader io.Reader
Writer *bufio.Writer
}
func NewLegacy(w http.ResponseWriter) (*LegacyPKT, error) {
hj, ok := w.(http.Hijacker)
if ok {
conn, rw, err := hj.Hijack()
l := &LegacyPKT{
Conn: conn,
ChunkedReader: httputil.NewChunkedReader(rw.Reader),
Writer: rw.Writer,
}
return l, err
}
return nil, errors.New("cannot hijack connection")
}
func (t *LegacyPKT) ReadPacket() (n int, p []byte, err error){
buf := make([]byte, 4096) // bufio.defaultBufSize
n, err = t.ChunkedReader.Read(buf)
p = make([]byte, n)
copy(p, buf)
return n, p, err
}
func (t *LegacyPKT) WritePacket(b []byte) (n int, err error) {
return t.Conn.Write(b)
}
func (t *LegacyPKT) Close() error {
return t.Conn.Close()
}
// [MS-TSGU]: Terminal Services Gateway Server Protocol version 39.0
// The server sends back the final status code 200 OK, and also a random entity body of limited size (100 bytes).
// This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does
// not specify an entity length in its response. It uses HTTP 1.0 semantics to send the entity body and closes the
// connection after the last byte is sent.
func (t *LegacyPKT) SendAccept(doSeed bool) {
t.Writer.WriteString(HttpOK)
t.Writer.WriteString("Date: " + time.Now().Format(time.RFC1123) + crlf)
if !doSeed {
t.Writer.WriteString("Content-Length: 0" + crlf)
}
t.Writer.WriteString(crlf)
if doSeed {
seed := make([]byte, 10)
rand.Read(seed)
// docs say it's a seed but 2019 responds with ab cd * 5
t.Writer.Write(seed)
}
t.Writer.Flush()
}
func (t *LegacyPKT) Drain() {
p := make([]byte, 32767)
t.Conn.Read(p)
}

8
transport/transport.go Normal file
View File

@@ -0,0 +1,8 @@
package transport
type HttpLayer interface {
ReadPacket() (n int, p []byte, err error)
WritePacket(b []byte) (n int, err error)
Close() error
}

42
transport/websocket.go Normal file
View File

@@ -0,0 +1,42 @@
package transport
import (
"errors"
"github.com/gorilla/websocket"
)
type WSPKT struct {
Conn *websocket.Conn
}
func NewWS(c *websocket.Conn) (*WSPKT, error) {
w := &WSPKT{Conn: c}
return w, nil
}
func (t *WSPKT) ReadPacket() (n int, b []byte, err error) {
mt, msg, err := t.Conn.ReadMessage()
if err != nil {
return 0, []byte{0, 0}, err
}
if mt == websocket.BinaryMessage {
return len(msg), msg, nil
}
return len(msg), msg, errors.New("not a binary packet")
}
func (t *WSPKT) WritePacket(b []byte) (n int, err error) {
err = t.Conn.WriteMessage(websocket.BinaryMessage, b)
if err != nil {
return 0, err
}
return len(b), nil
}
func (t *WSPKT) Close() error {
return t.Conn.Close()
}