mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-03-27 14:36:36 +00:00
226 lines
5.9 KiB
Go
226 lines
5.9 KiB
Go
package protocol
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/common"
|
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/patrickmn/go-cache"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"reflect"
|
|
"syscall"
|
|
"time"
|
|
)
|
|
|
|
// rdgConnectionIdKey is the request header whose value identifies a gateway
// connection; tunnels are cached and looked up under this value.
const rdgConnectionIdKey = "Rdg-Connection-Id"

// HTTP methods used by RD Gateway clients to open the two data channels.
const (
	// MethodRDGIN opens the client -> server data channel.
	MethodRDGIN = "RDG_IN_DATA"
	// MethodRDGOUT opens the server -> client data channel.
	MethodRDGOUT = "RDG_OUT_DATA"
)
|
|
|
|
// CheckPAACookieFunc is the signature of Gateway.CheckPAACookie, which
// verifies whether the PAA cookie sent by the client is valid.
type CheckPAACookieFunc func(ctx context.Context, cookie string) (bool, error)

// CheckClientNameFunc is the signature of Gateway.CheckClientName, which
// verifies whether the client name is allowed to connect.
type CheckClientNameFunc func(ctx context.Context, clientName string) (bool, error)

// CheckHostFunc is the signature of Gateway.CheckHost, which verifies whether
// the client is allowed to connect to the given remote host.
type CheckHostFunc func(ctx context.Context, host string) (bool, error)
|
|
|
|
type Gateway struct {
|
|
// CheckPAACookie verifies if the PAA cookie sent by the client is valid
|
|
CheckPAACookie CheckPAACookieFunc
|
|
|
|
// CheckClientName verifies if the client name is allowed to connect
|
|
CheckClientName CheckClientNameFunc
|
|
|
|
// CheckHost verifies if the client is allowed to connect to the remote host
|
|
CheckHost CheckHostFunc
|
|
|
|
// RedirectFlags sets what devices the client is allowed to redirect to the remote host
|
|
RedirectFlags RedirectFlags
|
|
|
|
// IdleTimeOut is used to determine when to disconnect clients that have been idle
|
|
IdleTimeout int
|
|
|
|
// SmartCardAuth sets whether to use smart card based authentication
|
|
SmartCardAuth bool
|
|
|
|
// TokenAuth sets whether to use token/cookie based authentication
|
|
TokenAuth bool
|
|
|
|
ReceiveBuf int
|
|
SendBuf int
|
|
}
|
|
|
|
var upgrader = websocket.Upgrader{}
|
|
var c = cache.New(5*time.Minute, 10*time.Minute)
|
|
|
|
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
|
connectionCache.Set(float64(c.ItemCount()))
|
|
|
|
var t *Tunnel
|
|
|
|
ctx := context.WithValue(r.Context(), common.TunnelCtx, t)
|
|
|
|
connId := r.Header.Get(rdgConnectionIdKey)
|
|
x, found := c.Get(connId)
|
|
if !found {
|
|
t = &Tunnel{
|
|
RDGId: connId,
|
|
RemoteAddr: ctx.Value(common.ClientIPCtx).(string),
|
|
UserName: ctx.Value(common.UsernameCtx).(string),
|
|
}
|
|
} else {
|
|
t = x.(*Tunnel)
|
|
}
|
|
|
|
if r.Method == MethodRDGOUT {
|
|
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
|
g.handleLegacyProtocol(w, r.WithContext(ctx), t)
|
|
return
|
|
}
|
|
r.Method = "GET" // force
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Printf("Cannot upgrade falling back to old protocol: %t", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
err = g.setSendReceiveBuffers(conn.UnderlyingConn())
|
|
if err != nil {
|
|
log.Printf("Cannot set send/receive buffers: %t", err)
|
|
}
|
|
|
|
g.handleWebsocketProtocol(ctx, conn, t)
|
|
} else if r.Method == MethodRDGIN {
|
|
g.handleLegacyProtocol(w, r.WithContext(ctx), t)
|
|
}
|
|
}
|
|
|
|
func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
|
|
if g.SendBuf < 1 && g.ReceiveBuf < 1 {
|
|
return nil
|
|
}
|
|
|
|
// conn == tls.Tunnel
|
|
ptr := reflect.ValueOf(conn)
|
|
val := reflect.Indirect(ptr)
|
|
|
|
if val.Kind() != reflect.Struct {
|
|
return errors.New("didn't get a struct from conn")
|
|
}
|
|
|
|
// this gets net.Tunnel -> *net.TCPConn -> net.TCPConn
|
|
ptrConn := val.FieldByName("conn")
|
|
valConn := reflect.Indirect(ptrConn)
|
|
if !valConn.IsValid() {
|
|
return errors.New("cannot find conn field")
|
|
}
|
|
valConn = valConn.Elem().Elem()
|
|
|
|
// net.FD
|
|
ptrNetFd := valConn.FieldByName("fd")
|
|
valNetFd := reflect.Indirect(ptrNetFd)
|
|
if !valNetFd.IsValid() {
|
|
return errors.New("cannot find fd field")
|
|
}
|
|
|
|
// pfd member
|
|
ptrPfd := valNetFd.FieldByName("pfd")
|
|
valPfd := reflect.Indirect(ptrPfd)
|
|
if !valPfd.IsValid() {
|
|
return errors.New("cannot find pfd field")
|
|
}
|
|
|
|
// finally the exported Sysfd
|
|
ptrSysFd := valPfd.FieldByName("Sysfd")
|
|
if !ptrSysFd.IsValid() {
|
|
return errors.New("cannot find Sysfd field")
|
|
}
|
|
fd := int(ptrSysFd.Int())
|
|
|
|
if g.ReceiveBuf > 0 {
|
|
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ReceiveBuf)
|
|
if err != nil {
|
|
return wrapSyscallError("setsockopt", err)
|
|
}
|
|
}
|
|
|
|
if g.SendBuf > 0 {
|
|
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.SendBuf)
|
|
if err != nil {
|
|
return wrapSyscallError("setsockopt", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, t *Tunnel) {
|
|
websocketConnections.Inc()
|
|
defer websocketConnections.Dec()
|
|
|
|
inout, _ := transport.NewWS(c)
|
|
defer inout.Close()
|
|
|
|
t.Id = uuid.New().String()
|
|
t.transportOut = inout
|
|
t.transportIn = inout
|
|
t.ConnectedOn = time.Now()
|
|
|
|
handler := NewProcessor(g, t)
|
|
RegisterTunnel(t, handler)
|
|
defer RemoveTunnel(t)
|
|
handler.Process(ctx)
|
|
}
|
|
|
|
// The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server
|
|
// and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different
|
|
// to ensure the connections do not get cached or terminated by a proxy prematurely.
|
|
func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) {
|
|
log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil)
|
|
|
|
if r.Method == MethodRDGOUT {
|
|
out, err := transport.NewLegacy(w)
|
|
if err != nil {
|
|
log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err)
|
|
return
|
|
}
|
|
log.Printf("Opening RDGOUT for client %s", common.GetClientIp(r.Context()))
|
|
|
|
t.transportOut = out
|
|
out.SendAccept(true)
|
|
|
|
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
|
} else if r.Method == MethodRDGIN {
|
|
legacyConnections.Inc()
|
|
defer legacyConnections.Dec()
|
|
|
|
in, err := transport.NewLegacy(w)
|
|
if err != nil {
|
|
log.Printf("cannot hijack connection to support RDG IN data channel: %s", err)
|
|
return
|
|
}
|
|
defer in.Close()
|
|
|
|
if t.transportIn == nil {
|
|
t.Id = uuid.New().String()
|
|
t.transportIn = in
|
|
c.Set(t.RDGId, t, cache.DefaultExpiration)
|
|
|
|
log.Printf("Opening RDGIN for client %s", common.GetClientIp(r.Context()))
|
|
in.SendAccept(false)
|
|
|
|
// read some initial data
|
|
in.Drain()
|
|
|
|
log.Printf("Legacy handshakeRequest done for client %s", common.GetClientIp(r.Context()))
|
|
handler := NewProcessor(g, t)
|
|
RegisterTunnel(t, handler)
|
|
defer RemoveTunnel(t)
|
|
handler.Process(r.Context())
|
|
}
|
|
}
|
|
}
|