This commit is contained in:
Bolke de Bruin
2020-08-01 21:23:34 +02:00
parent 4e99b4e88f
commit 9c19a1b40a
8 changed files with 144 additions and 87 deletions

View File

@@ -4,8 +4,8 @@ import (
"bytes"
"encoding/binary"
"fmt"
"github.com/bolkedebruin/rdpgw/transport"
"io"
"log"
"net"
)
@@ -17,10 +17,67 @@ const (
type ClientConfig struct {
SmartCardAuth bool
PAAToken string
NTLMAuth bool
GatewayConn transport.Transport
LocalConn net.Conn
PAAToken string
NTLMAuth bool
Session *SessionInfo
LocalConn net.Conn
Server string
Port int
Name string
}
func (c *ClientConfig) ConnectAndForward() error {
c.Session.TransportOut.WritePacket(c.handshakeRequest())
for {
pt, sz, pkt, err := readMessage(c.Session.TransportIn)
if err != nil {
log.Printf("Cannot read message from stream %s", err)
return err
}
switch pt {
case PKT_TYPE_HANDSHAKE_RESPONSE:
caps, err := c.handshakeResponse(pkt)
if err != nil {
log.Printf("Cannot connect to %s due to %s", c.Server, err)
return err
}
log.Printf("Handshake response received. Caps: %d", caps)
c.Session.TransportOut.WritePacket(c.tunnelRequest())
case PKT_TYPE_TUNNEL_RESPONSE:
tid, caps, err := c.tunnelResponse(pkt)
if err != nil {
log.Printf("Cannot setup tunnel due to %s", err)
return err
}
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
c.Session.TransportOut.WritePacket(c.tunnelAuthRequest())
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
flags, timeout, err := c.tunnelAuthResponse(pkt)
if err != nil {
log.Printf("Cannot do tunnel auth due to %s", err)
return err
}
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
c.Session.TransportOut.WritePacket(c.channelRequest())
case PKT_TYPE_CHANNEL_RESPONSE:
cid, err := c.channelResponse(pkt)
if err != nil {
log.Printf("Cannot do tunnel auth due to %s", err)
return err
}
if cid < 1 {
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
}
log.Printf("Channel creation succesful. Channel id: %d", cid)
go forward(c.LocalConn, c.Session.TransportOut)
case PKT_TYPE_DATA:
receive(pkt, c.LocalConn)
default:
log.Printf("Unknown packet type received: %d size %d", pt, sz)
}
}
}
func (c *ClientConfig) handshakeRequest() []byte {
@@ -83,7 +140,7 @@ func (c *ClientConfig) tunnelRequest() []byte {
binary.Write(buf, binary.LittleEndian, caps)
binary.Write(buf, binary.LittleEndian, fields)
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved
if len(c.PAAToken) > 0 {
utf16Token := EncodeUTF16(c.PAAToken)
@@ -119,8 +176,8 @@ func (c *ClientConfig) tunnelResponse(data []byte) (tunnelId uint32, caps uint32
return
}
func (c *ClientConfig) tunnelAuthRequest(name string) []byte {
utf16name := EncodeUTF16(name)
func (c *ClientConfig) tunnelAuthRequest() []byte {
utf16name := EncodeUTF16(c.Name)
size := uint16(len(utf16name))
buf := new(bytes.Buffer)
@@ -153,14 +210,14 @@ func (c *ClientConfig) tunnelAuthResponse(data []byte) (flags uint32, timeout ui
return
}
func (c *ClientConfig) channelRequest(server string, port uint16) []byte {
utf16server := EncodeUTF16(server)
func (c *ClientConfig) channelRequest() []byte {
utf16server := EncodeUTF16(c.Server)
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names
binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3)
binary.Write(buf, binary.LittleEndian, uint16(port))
binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3
binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names
binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3)
binary.Write(buf, binary.LittleEndian, uint16(c.Port))
binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3
binary.Write(buf, binary.LittleEndian, uint16(len(utf16server)))
buf.Write(utf16server)

View File

@@ -10,6 +10,62 @@ import (
"net"
)
type RedirectFlags struct {
Clipboard bool
Port bool
Drive bool
Printer bool
Pnp bool
DisableAll bool
EnableAll bool
}
type SessionInfo struct {
ConnId string
TransportIn transport.Transport
TransportOut transport.Transport
RemoteServer string
ClientIp string
}
func readMessage(in transport.Transport) (pt int, n int, msg []byte, err error) {
fragment := false
index := 0
buf := make([]byte, 4096)
for {
size, pkt, err := in.ReadPacket()
if err != nil {
return 0, 0, []byte{0, 0}, err
}
// check for fragments
var pt uint16
var sz uint32
var msg []byte
if !fragment {
pt, sz, msg, err = readHeader(pkt[:size])
if err != nil {
fragment = true
index = copy(buf, pkt[:size])
continue
}
index = 0
} else {
fragment = false
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
// header is corrupted even after defragmenting
if err != nil {
return 0, 0, []byte{0, 0}, err
}
}
if !fragment {
return int(pt), int(sz), msg, nil
}
}
}
func createPacket(pktType uint16, data []byte) (packet []byte) {
size := len(data) + 8
buf := new(bytes.Buffer)

View File

@@ -2,7 +2,7 @@ package protocol
import (
"context"
"github.com/bolkedebruin/rdpgw/client"
"github.com/bolkedebruin/rdpgw/common"
"github.com/bolkedebruin/rdpgw/transport"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
@@ -45,14 +45,6 @@ type Gateway struct {
ServerConf *ServerConf
}
type SessionInfo struct {
ConnId string
TransportIn transport.Transport
TransportOut transport.Transport
RemoteServer string
ClientIp string
}
var upgrader = websocket.Upgrader{}
var c = cache.New(5*time.Minute, 10*time.Minute)
@@ -118,7 +110,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return
}
log.Printf("Opening RDGOUT for client %s", client.GetClientIp(r.Context()))
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
s.TransportOut = out
out.SendAccept(true)
@@ -139,13 +131,13 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
s.TransportIn = in
c.Set(s.ConnId, s, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %s", client.GetClientIp(r.Context()))
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
in.SendAccept(false)
// read some initial data
in.Drain()
log.Printf("Legacy handshakeRequest done for client %s", client.GetClientIp(r.Context()))
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
handler := NewServer(s, g.ServerConf)
handler.Process(r.Context())
}

View File

@@ -204,4 +204,4 @@ func TestChannelCreation(t *testing.T) {
if channelId < 1 {
t.Fatalf("channelResponse failed got channeld id %d, expected > 0", channelId)
}
}
}

View File

@@ -5,7 +5,7 @@ import (
"context"
"encoding/binary"
"errors"
"github.com/bolkedebruin/rdpgw/client"
"github.com/bolkedebruin/rdpgw/common"
"io"
"log"
"net"
@@ -17,16 +17,6 @@ type VerifyTunnelCreate func(context.Context, string) (bool, error)
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
type VerifyServerFunc func(context.Context, string) (bool, error)
type RedirectFlags struct {
Clipboard bool
Port bool
Drive bool
Printer bool
Pnp bool
DisableAll bool
EnableAll bool
}
type Server struct {
Session *SessionInfo
VerifyTunnelCreate VerifyTunnelCreate
@@ -70,7 +60,7 @@ const tunnelId = 10
func (s *Server) Process(ctx context.Context) error {
for {
pt, sz, pkt, err := s.ReadMessage()
pt, sz, pkt, err := readMessage(s.Session.TransportIn)
if err != nil {
log.Printf("Cannot read message from stream %s", err)
return err
@@ -78,7 +68,7 @@ func (s *Server) Process(ctx context.Context) error {
switch pt {
case PKT_TYPE_HANDSHAKE_REQUEST:
log.Printf("Client handshakeRequest from %s", client.GetClientIp(ctx))
log.Printf("Client handshakeRequest from %s", common.GetClientIp(ctx))
if s.State != SERVER_STATE_INITIAL {
log.Printf("Handshake attempted while in wrong state %d != %d", s.State, SERVER_STATE_INITIAL)
return errors.New("wrong state")
@@ -97,7 +87,7 @@ func (s *Server) Process(ctx context.Context) error {
_, cookie := s.tunnelRequest(pkt)
if s.VerifyTunnelCreate != nil {
if ok, _ := s.VerifyTunnelCreate(ctx, cookie); !ok {
log.Printf("Invalid PAA cookie received from client %s", client.GetClientIp(ctx))
log.Printf("Invalid PAA cookie received from client %s", common.GetClientIp(ctx))
return errors.New("invalid PAA cookie")
}
}
@@ -181,44 +171,6 @@ func (s *Server) Process(ctx context.Context) error {
}
}
func (s *Server) ReadMessage() (pt int, n int, msg []byte, err error) {
fragment := false
index := 0
buf := make([]byte, 4096)
for {
size, pkt, err := s.Session.TransportIn.ReadPacket()
if err != nil {
return 0, 0, []byte{0, 0}, err
}
// check for fragments
var pt uint16
var sz uint32
var msg []byte
if !fragment {
pt, sz, msg, err = readHeader(pkt[:size])
if err != nil {
fragment = true
index = copy(buf, pkt[:size])
continue
}
index = 0
} else {
fragment = false
pt, sz, msg, err = readHeader(append(buf[:index], pkt[:size]...))
// header is corrupted even after defragmenting
if err != nil {
return 0, 0, []byte{0, 0}, err
}
}
if !fragment {
return int(pt), int(sz), msg, nil
}
}
}
// Creates a packet the is a response to a handshakeRequest request
// HTTP_EXTENDED_AUTH_SSPI_NTLM is not supported in Linux
// but could be in Windows. However the NTLM protocol is insecure