Make sure to use right keys

This commit is contained in:
Bolke de Bruin
2022-09-24 16:47:03 +02:00
parent 94d7cddc4b
commit 0566f90488
6 changed files with 27 additions and 29 deletions

View File

@@ -27,10 +27,10 @@ type ClientConfig struct {
}
func (c *ClientConfig) ConnectAndForward() error {
c.Session.TransportOut.WritePacket(c.handshakeRequest())
c.Session.transportOut.WritePacket(c.handshakeRequest())
for {
pt, sz, pkt, err := readMessage(c.Session.TransportIn)
pt, sz, pkt, err := readMessage(c.Session.transportIn)
if err != nil {
log.Printf("Cannot read message from stream %s", err)
return err
@@ -44,7 +44,7 @@ func (c *ClientConfig) ConnectAndForward() error {
return err
}
log.Printf("Handshake response received. Caps: %d", caps)
c.Session.TransportOut.WritePacket(c.tunnelRequest())
c.Session.transportOut.WritePacket(c.tunnelRequest())
case PKT_TYPE_TUNNEL_RESPONSE:
tid, caps, err := c.tunnelResponse(pkt)
if err != nil {
@@ -52,7 +52,7 @@ func (c *ClientConfig) ConnectAndForward() error {
return err
}
log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps)
c.Session.TransportOut.WritePacket(c.tunnelAuthRequest())
c.Session.transportOut.WritePacket(c.tunnelAuthRequest())
case PKT_TYPE_TUNNEL_AUTH_RESPONSE:
flags, timeout, err := c.tunnelAuthResponse(pkt)
if err != nil {
@@ -60,7 +60,7 @@ func (c *ClientConfig) ConnectAndForward() error {
return err
}
log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout)
c.Session.TransportOut.WritePacket(c.channelRequest())
c.Session.transportOut.WritePacket(c.channelRequest())
case PKT_TYPE_CHANNEL_RESPONSE:
cid, err := c.channelResponse(pkt)
if err != nil {
@@ -71,7 +71,7 @@ func (c *ClientConfig) ConnectAndForward() error {
log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid)
}
log.Printf("Channel creation succesful. Channel id: %d", cid)
//go forward(c.LocalConn, c.Session.TransportOut)
//go forward(c.LocalConn, c.Session.transportOut)
case PKT_TYPE_DATA:
receive(pkt, c.LocalConn)
default:

View File

@@ -165,8 +165,8 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
defer inout.Close()
t.Id = uuid.New().String()
t.TransportOut = inout
t.TransportIn = inout
t.transportOut = inout
t.transportIn = inout
t.ConnectedOn = time.Now()
handler := NewProcessor(g, t)
@@ -179,17 +179,17 @@ func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn
// and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different
// to ensure the connections do not get cached or terminated by a proxy prematurely.
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
log.Printf("Session %t, %t, %t", t.RDGId, t.TransportOut != nil, t.TransportIn != nil)
log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil)
if r.Method == MethodRDGOUT {
out, err := transport.NewLegacy(w)
if err != nil {
log.Printf("cannot hijack connection to support RDG OUT data channel: %t", err)
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
return
}
log.Printf("Opening RDGOUT for client %t", common.GetClientIp(r.Context()))
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
t.TransportOut = out
t.transportOut = out
out.SendAccept(true)
c.Set(t.RDGId, t, cache.DefaultExpiration)
@@ -199,23 +199,23 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t
in, err := transport.NewLegacy(w)
if err != nil {
log.Printf("cannot hijack connection to support RDG IN data channel: %t", err)
log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
return
}
defer in.Close()
if t.TransportIn == nil {
if t.transportIn == nil {
t.Id = uuid.New().String()
t.TransportIn = in
t.transportIn = in
c.Set(t.RDGId, t, cache.DefaultExpiration)
log.Printf("Opening RDGIN for client %t", common.GetClientIp(r.Context()))
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
in.SendAccept(false)
// read some initial data
in.Drain()
log.Printf("Legacy handshakeRequest done for client %t", common.GetClientIp(r.Context()))
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
handler := NewProcessor(g, t)
RegisterTunnel(t, handler)
defer RemoveTunnel(t)

View File

@@ -159,7 +159,7 @@ func (p *Processor) Process(ctx context.Context) error {
}
// avoid concurrency issues
// p.TransportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
// p.transportIn.Write(createPacket(PKT_TYPE_KEEPALIVE, []byte{}))
case PKT_TYPE_CLOSE_CHANNEL:
log.Printf("Close channel")
if p.state != SERVER_STATE_OPENED {
@@ -168,8 +168,8 @@ func (p *Processor) Process(ctx context.Context) error {
}
msg := p.channelCloseResponse(ERROR_SUCCESS)
p.tunnel.Write(msg)
//p.tunnel.TransportIn.Close()
//p.tunnel.TransportOut.Close()
//p.tunnel.transportIn.Close()
//p.tunnel.transportOut.Close()
p.state = SERVER_STATE_CLOSED
return nil
default:

View File

@@ -12,11 +12,11 @@ type Tunnel struct {
// The connection-id (RDG-ConnID) as reported by the client
RDGId string
// The underlying incoming transport being either websocket or legacy http
// in case of websocket TransportOut will equal TransportIn
TransportIn transport.Transport
// in case of websocket transportOut will equal transportIn
transportIn transport.Transport
// The underlying outgoing transport being either websocket or legacy http
// in case of websocket TransportOut will equal TransportOut
TransportOut transport.Transport
// in case of websocket transportOut will equal transportOut
transportOut transport.Transport
// The remote desktop server (rdp, vnc etc) the clients intends to connect to
TargetServer string
// The obtained client ip address
@@ -43,7 +43,7 @@ type Tunnel struct {
// Write puts the packet on the transport and updates the statistics for bytes sent
func (t *Tunnel) Write(pkt []byte) {
n, _ := t.TransportOut.WritePacket(pkt)
n, _ := t.transportOut.WritePacket(pkt)
t.BytesSent += int64(n)
}
@@ -51,7 +51,7 @@ func (t *Tunnel) Write(pkt []byte) {
// packet, with the header removed, and the packet size. It updates the
// statistics for bytes received
func (t *Tunnel) Read() (pt int, size int, pkt []byte, err error) {
pt, size, pkt, err = readMessage(t.TransportIn)
pt, size, pkt, err = readMessage(t.transportIn)
t.BytesReceived += int64(size)
t.LastSeen = time.Now()

View File

@@ -10,8 +10,6 @@ import (
var (
info = protocol.Tunnel{
RDGId: "myid",
TransportIn: nil,
TransportOut: nil,
TargetServer: "my.remote.server",
RemoteAddr: "10.0.0.1",
UserName: "Frank",

View File

@@ -289,7 +289,7 @@ func GenerateQueryToken(ctx context.Context, query string, issuer string) (strin
}
func getTunnel(ctx context.Context) *protocol.Tunnel {
s, ok := ctx.Value("Tunnel").(*protocol.Tunnel)
s, ok := ctx.Value(common.TunnelCtx).(*protocol.Tunnel)
if !ok {
log.Printf("cannot get session info from context")
return nil