mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-03-27 14:36:36 +00:00
Refactor some stuff
This commit is contained in:
@@ -51,10 +51,10 @@ type HandlerConf struct {
|
||||
TokenAuth bool
|
||||
}
|
||||
|
||||
func NewHandler(in transport.Transport, out transport.Transport, conf *HandlerConf) *Handler {
|
||||
func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
|
||||
h := &Handler{
|
||||
TransportIn: in,
|
||||
TransportOut: out,
|
||||
TransportIn: s.TransportIn,
|
||||
TransportOut: s.TransportOut,
|
||||
State: SERVER_STATE_INITIAL,
|
||||
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
||||
IdleTimeout: conf.IdleTimeout,
|
||||
|
||||
@@ -51,10 +51,9 @@ type SessionInfo struct {
|
||||
TransportOut transport.Transport
|
||||
RemoteAddress string
|
||||
ProxyAddresses string
|
||||
UserName string
|
||||
}
|
||||
|
||||
var DefaultSession SessionInfo
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
var c = cache.New(5*time.Minute, 10*time.Minute)
|
||||
|
||||
@@ -66,9 +65,20 @@ func init() {
|
||||
|
||||
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
connectionCache.Set(float64(c.ItemCount()))
|
||||
|
||||
var s *SessionInfo
|
||||
|
||||
connId := r.Header.Get(rdgConnectionIdKey)
|
||||
x, found := c.Get(connId)
|
||||
if !found {
|
||||
s = &SessionInfo{ConnId: connId}
|
||||
} else {
|
||||
s = x.(*SessionInfo)
|
||||
}
|
||||
|
||||
if r.Method == MethodRDGOUT {
|
||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||
g.handleLegacyProtocol(w, r)
|
||||
g.handleLegacyProtocol(w, r, s)
|
||||
return
|
||||
}
|
||||
r.Method = "GET" // force
|
||||
@@ -79,35 +89,27 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
g.handleWebsocketProtocol(conn)
|
||||
g.handleWebsocketProtocol(conn, s)
|
||||
} else if r.Method == MethodRDGIN {
|
||||
g.handleLegacyProtocol(w, r)
|
||||
g.handleLegacyProtocol(w, r, s)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn) {
|
||||
func (g *Gateway) handleWebsocketProtocol(c *websocket.Conn, s *SessionInfo) {
|
||||
websocketConnections.Inc()
|
||||
defer websocketConnections.Dec()
|
||||
|
||||
inout, _ := transport.NewWS(c)
|
||||
handler := NewHandler(inout, inout, g.HandlerConf)
|
||||
s.TransportOut = inout
|
||||
s.TransportIn = inout
|
||||
handler := NewHandler(s, g.HandlerConf)
|
||||
handler.Process()
|
||||
}
|
||||
|
||||
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
||||
// and RDG_OUT_DATA for server -> client data. The handshake procedure is a bit different
|
||||
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
||||
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
var s SessionInfo
|
||||
|
||||
connId := r.Header.Get(rdgConnectionIdKey)
|
||||
x, found := c.Get(connId)
|
||||
if !found {
|
||||
s = SessionInfo{ConnId: connId}
|
||||
} else {
|
||||
s = x.(SessionInfo)
|
||||
}
|
||||
|
||||
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, s *SessionInfo) {
|
||||
log.Printf("Session %s, %t, %t", s.ConnId, s.TransportOut != nil, s.TransportIn != nil)
|
||||
|
||||
if r.Method == MethodRDGOUT {
|
||||
@@ -121,7 +123,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
s.TransportOut = out
|
||||
out.SendAccept(true)
|
||||
|
||||
c.Set(connId, s, cache.DefaultExpiration)
|
||||
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
||||
} else if r.Method == MethodRDGIN {
|
||||
legacyConnections.Inc()
|
||||
defer legacyConnections.Dec()
|
||||
@@ -135,7 +137,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if s.TransportIn == nil {
|
||||
s.TransportIn = in
|
||||
c.Set(connId, s, cache.DefaultExpiration)
|
||||
c.Set(s.ConnId, s, cache.DefaultExpiration)
|
||||
|
||||
log.Printf("Opening RDGIN for client %s", in.Conn.RemoteAddr().String())
|
||||
in.SendAccept(false)
|
||||
@@ -144,7 +146,7 @@ func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
in.Drain()
|
||||
|
||||
log.Printf("Legacy handshake done for client %s", in.Conn.RemoteAddr().String())
|
||||
handler := NewHandler(in, s.TransportOut, g.HandlerConf)
|
||||
handler := NewHandler(s, g.HandlerConf)
|
||||
handler.Process()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user