mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-03-31 16:06:35 +00:00
Make sure to use right keys
This commit is contained in:
@@ -27,10 +27,10 @@ type ClientConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientConfig) ConnectAndForward() error {
|
func (c *ClientConfig) ConnectAndForward() error {
|
||||||
c.Session.TransportOut.WritePacket(c.handshakeRequest())
|
c.Session.transportOut.WritePacket(c.handshakeRequest())
|
||||||
|
|
||||||
for {
|
for {
|
||||||
pt, sz, pkt, err := readMessage(c.Session.TransportIn)
|
pt, sz, pkt, err := readMessage(c.Session.transportIn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot read message from stream %s", err)
|
log.Printf("Cannot read message from stream %s", err)
|
||||||
return err
|
return err
|
||||||
@@ -44,7 +44,7 @@ func (c *ClientConfig) ConnectAndForward() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Printf("Handshake response received. Caps: %d", caps)
|
log.Printf("Handshake response received. Caps: %d", caps)
|
||||||
c.Session.TransportOut.WritePacket(c.tunnelRequest())
|
c.Session.transportOut.WritePacket(c.tunnelRequest())
|
||||||
case PKT_TYPE_TUNNEL_RESPONSE:
|
case PKT_TYPE_TUNNEL_RESPONSE:
|
||||||
tid, caps, err := c.tunnelResponse(pkt)
|
tid, caps, err := c.tunnelResponse(pkt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -52,7 +52,7 @@ func (c *ClientConfig) ConnectAndForward() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
|
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
|
||||||
c.Session.TransportOut.WritePacket(c.tunnelAuthRequest())
|
c.Session.transportOut.WritePacket(c.tunnelAuthRequest())
|
||||||
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
|
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
|
||||||
flags, timeout, err := c.tunnelAuthResponse(pkt)
|
flags, timeout, err := c.tunnelAuthResponse(pkt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -60,7 +60,7 @@ func (c *ClientConfig) ConnectAndForward() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
|
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
|
||||||
c.Session.TransportOut.WritePacket(c.channelRequest())
|
c.Session.transportOut.WritePacket(c.channelRequest())
|
||||||
case PKT_TYPE_CHANNEL_RESPONSE:
|
case PKT_TYPE_CHANNEL_RESPONSE:
|
||||||
cid, err := c.channelResponse(pkt)
|
cid, err := c.channelResponse(pkt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -71,7 +71,7 @@ func (c *ClientConfig) ConnectAndForward() error {
|
|||||||
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
|
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
|
||||||
}
|
}
|
||||||
log.Printf("Channel creation succesful. Channel id: %d", cid)
|
log.Printf("Channel creation succesful. Channel id: %d", cid)
|
||||||
//go forward(c.LocalConn, c.Session.TransportOut)
|
//go forward(c.LocalConn, c.Session.transportOut)
|
||||||
case PKT_TYPE_DATA:
|
case PKT_TYPE_DATA:
|
||||||
receive(pkt, c.LocalConn)
|
receive(pkt, c.LocalConn)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -165,8 +165,8 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
|
|||||||
defer inout.Close()
|
defer inout.Close()
|
||||||
|
|
||||||
t.Id = uuid.New().String()
|
t.Id = uuid.New().String()
|
||||||
t.TransportOut = inout
|
t.transportOut = inout
|
||||||
t.TransportIn = inout
|
t.transportIn = inout
|
||||||
t.ConnectedOn = time.Now()
|
t.ConnectedOn = time.Now()
|
||||||
|
|
||||||
handler := NewProcessor(g, t)
|
handler := NewProcessor(g, t)
|
||||||
@@ -179,17 +179,17 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
|
|||||||
// and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different
|
// and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different
|
||||||
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
||||||
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
|
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
|
||||||
log.Printf("Session %t, %t, %t", t.RDGId, t.TransportOut != nil, t.TransportIn != nil)
|
log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil)
|
||||||
|
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
out, err := transport.NewLegacy(w)
|
out, err := transport.NewLegacy(w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("cannot hijack connection to support RDG OUT data channel: %t", err)
|
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("Opening RDGOUT for client %t", common.GetClientIp(r.Context()))
|
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
|
||||||
|
|
||||||
t.TransportOut = out
|
t.transportOut = out
|
||||||
out.SendAccept(true)
|
out.SendAccept(true)
|
||||||
|
|
||||||
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
||||||
@@ -199,23 +199,23 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
|
|||||||
|
|
||||||
in, err := transport.NewLegacy(w)
|
in, err := transport.NewLegacy(w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("cannot hijack connection to support RDG IN data channel: %t", err)
|
log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer in.Close()
|
defer in.Close()
|
||||||
|
|
||||||
if t.TransportIn == nil {
|
if t.transportIn == nil {
|
||||||
t.Id = uuid.New().String()
|
t.Id = uuid.New().String()
|
||||||
t.TransportIn = in
|
t.transportIn = in
|
||||||
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
||||||
|
|
||||||
log.Printf("Opening RDGIN for client %t", common.GetClientIp(r.Context()))
|
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
|
||||||
in.SendAccept(false)
|
in.SendAccept(false)
|
||||||
|
|
||||||
// read some initial data
|
// read some initial data
|
||||||
in.Drain()
|
in.Drain()
|
||||||
|
|
||||||
log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context()))
|
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
|
||||||
handler := NewProcessor(g, t)
|
handler := NewProcessor(g, t)
|
||||||
RegisterTunnel(t, handler)
|
RegisterTunnel(t, handler)
|
||||||
defer RemoveTunnel(t)
|
defer RemoveTunnel(t)
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ func (p *Processor) Process(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// avoid concurrency issues
|
// avoid concurrency issues
|
||||||
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
// p.transportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
|
||||||
case PKT_TYPE_CLOSE_CHANNEL:
|
case PKT_TYPE_CLOSE_CHANNEL:
|
||||||
log.Printf("Close channel")
|
log.Printf("Close channel")
|
||||||
if p.state != SERVER_STATE_OPENED {
|
if p.state != SERVER_STATE_OPENED {
|
||||||
@@ -168,8 +168,8 @@ func (p *Processor) Process(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
msg := p.channelCloseResponse(ERROR_SUCCESS)
|
msg := p.channelCloseResponse(ERROR_SUCCESS)
|
||||||
p.tunnel.Write(msg)
|
p.tunnel.Write(msg)
|
||||||
//p.tunnel.TransportIn.Close()
|
//p.tunnel.transportIn.Close()
|
||||||
//p.tunnel.TransportOut.Close()
|
//p.tunnel.transportOut.Close()
|
||||||
p.state = SERVER_STATE_CLOSED
|
p.state = SERVER_STATE_CLOSED
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -12,11 +12,11 @@ type Tunnel struct {
|
|||||||
// The connection-id (RDG-ConnID) as reported by the client
|
// The connection-id (RDG-ConnID) as reported by the client
|
||||||
RDGId string
|
RDGId string
|
||||||
// The underlying incoming transport being either websocket or legacy http
|
// The underlying incoming transport being either websocket or legacy http
|
||||||
// in case of websocket TransportOut will equal TransportIn
|
// in case of websocket transportOut will equal transportIn
|
||||||
TransportIn transport.Transport
|
transportIn transport.Transport
|
||||||
// The underlying outgoing transport being either websocket or legacy http
|
// The underlying outgoing transport being either websocket or legacy http
|
||||||
// in case of websocket TransportOut will equal TransportOut
|
// in case of websocket transportOut will equal transportOut
|
||||||
TransportOut transport.Transport
|
transportOut transport.Transport
|
||||||
// The remote desktop server (rdp, vnc etc) the clients intends to connect to
|
// The remote desktop server (rdp, vnc etc) the clients intends to connect to
|
||||||
TargetServer string
|
TargetServer string
|
||||||
// The obtained client ip address
|
// The obtained client ip address
|
||||||
@@ -43,7 +43,7 @@ type Tunnel struct {
|
|||||||
|
|
||||||
// Write puts the packet on the transport and updates the statistics for bytes sent
|
// Write puts the packet on the transport and updates the statistics for bytes sent
|
||||||
func (t *Tunnel) Write(pkt []byte) {
|
func (t *Tunnel) Write(pkt []byte) {
|
||||||
n, _ := t.TransportOut.WritePacket(pkt)
|
n, _ := t.transportOut.WritePacket(pkt)
|
||||||
t.BytesSent += int64(n)
|
t.BytesSent += int64(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ func (t *Tunnel) Write(pkt []byte) {
|
|||||||
// packet, with the header removed, and the packet size. It updates the
|
// packet, with the header removed, and the packet size. It updates the
|
||||||
// statistics for bytes received
|
// statistics for bytes received
|
||||||
func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) {
|
func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) {
|
||||||
pt, size, pkt, err = readMessage(t.TransportIn)
|
pt, size, pkt, err = readMessage(t.transportIn)
|
||||||
t.BytesReceived += int64(size)
|
t.BytesReceived += int64(size)
|
||||||
t.LastSeen = time.Now()
|
t.LastSeen = time.Now()
|
||||||
|
|
||||||
|
|||||||
@@ -10,8 +10,6 @@ import (
|
|||||||
var (
|
var (
|
||||||
info = protocol.Tunnel{
|
info = protocol.Tunnel{
|
||||||
RDGId: "myid",
|
RDGId: "myid",
|
||||||
TransportIn: nil,
|
|
||||||
TransportOut: nil,
|
|
||||||
TargetServer: "my.remote.server",
|
TargetServer: "my.remote.server",
|
||||||
RemoteAddr: "10.0.0.1",
|
RemoteAddr: "10.0.0.1",
|
||||||
UserName: "Frank",
|
UserName: "Frank",
|
||||||
|
|||||||
@@ -289,7 +289,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getTunnel(ctx context.Context) *protocol.Tunnel {
|
func getTunnel(ctx context.Context) *protocol.Tunnel {
|
||||||
s, ok := ctx.Value("Tunnel").(*protocol.Tunnel)
|
s, ok := ctx.Value(common.TunnelCtx).(*protocol.Tunnel)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("cannot get session info from context")
|
log.Printf("cannot get session info from context")
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
Reference in New Issue
Block a user