mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-03-27 22:46:37 +00:00
Use context
This commit is contained in:
@@ -2,6 +2,7 @@ package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/transport"
|
||||
@@ -71,7 +72,7 @@ func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
|
||||
|
||||
const tunnelId = 10
|
||||
|
||||
func (h *Handler) Process() error {
|
||||
func (h *Handler) Process(ctx context.Context) error {
|
||||
for {
|
||||
pt, sz, pkt, err := h.ReadMessage()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bolkedebruin/rdpgw/transport"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/patrickmn/go-cache"
|
||||
@@ -64,6 +65,9 @@ func init() {
|
||||
}
|
||||
|
||||
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
connectionCache.Set(float64(c.ItemCount()))
|
||||
|
||||
var s *SessionInfo
|
||||
@@ -78,7 +82,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
if r.Method == MethodRDGOUT {
|
||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||
g.handleLegacyProtocol(w, r, s)
|
||||
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
|
||||
return
|
||||
}
|
||||
r.Method = "GET" // force
|
||||
@@ -89,13 +93,13 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
g.handleWebsocketProtocol(conn, s)
|
||||
g.handleWebsocketProtocol(ctx, conn, s)
|
||||
} else if r.Method == MethodRDGIN {
|
||||
g.handleLegacyProtocol(w, r, s)
|
||||
g.handleLegacyProtocol(w, r.WithContext(ctx), s)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
|
||||
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, s *SessionInfo) {
|
||||
websocketConnections.Inc()
|
||||
defer websocketConnections.Dec()
|
||||
|
||||
@@ -103,7 +107,7 @@ func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
|
||||
s.TransportOut = inout
|
||||
s.TransportIn = inout
|
||||
handler := NewHandler(s, g.HandlerConf)
|
||||
handler.Process()
|
||||
handler.Process(ctx)
|
||||
}
|
||||
|
||||
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
||||
@@ -147,7 +151,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s
|
||||
|
||||
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
|
||||
handler := NewHandler(s, g.HandlerConf)
|
||||
handler.Process()
|
||||
handler.Process(r.Context())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user