mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-03-28 06:56:34 +00:00
Check hostname specified by client against the token
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/bolkedebruin/rdpgw/transport"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
@@ -13,9 +12,9 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type VerifyTunnelCreate func(*SessionInfo, string) (bool, error)
|
||||
type VerifyTunnelAuthFunc func(*SessionInfo, string) (bool, error)
|
||||
type VerifyServerFunc func(*SessionInfo, string) (bool, error)
|
||||
type VerifyTunnelCreate func(context.Context, string) (bool, error)
|
||||
type VerifyTunnelAuthFunc func(context.Context, string) (bool, error)
|
||||
type VerifyServerFunc func(context.Context, string) (bool, error)
|
||||
|
||||
type RedirectFlags struct {
|
||||
Clipboard bool
|
||||
@@ -29,8 +28,6 @@ type RedirectFlags struct {
|
||||
|
||||
type Handler struct {
|
||||
Session *SessionInfo
|
||||
TransportIn transport.Transport
|
||||
TransportOut transport.Transport
|
||||
VerifyTunnelCreate VerifyTunnelCreate
|
||||
VerifyTunnelAuthFunc VerifyTunnelAuthFunc
|
||||
VerifyServerFunc VerifyServerFunc
|
||||
@@ -55,10 +52,8 @@ type HandlerConf struct {
|
||||
|
||||
func NewHandler(s *SessionInfo, conf *HandlerConf) *Handler {
|
||||
h := &Handler{
|
||||
State: SERVER_STATE_INITIAL,
|
||||
Session: s,
|
||||
TransportIn: s.TransportIn,
|
||||
TransportOut: s.TransportOut,
|
||||
State: SERVER_STATE_INITIAL,
|
||||
RedirectFlags: makeRedirectFlags(conf.RedirectFlags),
|
||||
IdleTimeout: conf.IdleTimeout,
|
||||
SmartCardAuth: conf.SmartCardAuth,
|
||||
@@ -89,7 +84,7 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||
}
|
||||
major, minor, _, _ := readHandshake(pkt) // todo check if auth matches what the handler can do
|
||||
msg := h.handshakeResponse(major, minor)
|
||||
h.TransportOut.WritePacket(msg)
|
||||
h.Session.TransportOut.WritePacket(msg)
|
||||
h.State = SERVER_STATE_HANDSHAKE
|
||||
case PKT_TYPE_TUNNEL_CREATE:
|
||||
log.Printf("Tunnel create")
|
||||
@@ -100,13 +95,13 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||
}
|
||||
_, cookie := readCreateTunnelRequest(pkt)
|
||||
if h.VerifyTunnelCreate != nil {
|
||||
if ok, _ := h.VerifyTunnelCreate(h.Session, cookie); !ok {
|
||||
if ok, _ := h.VerifyTunnelCreate(ctx, cookie); !ok {
|
||||
log.Printf("Invalid PAA cookie received")
|
||||
return errors.New("invalid PAA cookie")
|
||||
}
|
||||
}
|
||||
msg := createTunnelResponse()
|
||||
h.TransportOut.WritePacket(msg)
|
||||
h.Session.TransportOut.WritePacket(msg)
|
||||
h.State = SERVER_STATE_TUNNEL_CREATE
|
||||
case PKT_TYPE_TUNNEL_AUTH:
|
||||
log.Printf("Tunnel auth")
|
||||
@@ -117,13 +112,13 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||
}
|
||||
client := h.readTunnelAuthRequest(pkt)
|
||||
if h.VerifyTunnelAuthFunc != nil {
|
||||
if ok, _ := h.VerifyTunnelAuthFunc(h.Session, client); !ok {
|
||||
if ok, _ := h.VerifyTunnelAuthFunc(ctx, client); !ok {
|
||||
log.Printf("Invalid client name: %s", client)
|
||||
return errors.New("invalid client name")
|
||||
}
|
||||
}
|
||||
msg := h.createTunnelAuthResponse()
|
||||
h.TransportOut.WritePacket(msg)
|
||||
h.Session.TransportOut.WritePacket(msg)
|
||||
h.State = SERVER_STATE_TUNNEL_AUTHORIZE
|
||||
case PKT_TYPE_CHANNEL_CREATE:
|
||||
log.Printf("Channel create")
|
||||
@@ -135,8 +130,9 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||
server, port := readChannelCreateRequest(pkt)
|
||||
host := net.JoinHostPort(server, strconv.Itoa(int(port)))
|
||||
if h.VerifyServerFunc != nil {
|
||||
if ok, _ := h.VerifyServerFunc(h.Session, host); !ok {
|
||||
if ok, _ := h.VerifyServerFunc(ctx, host); !ok {
|
||||
log.Printf("Not allowed to connect to %s by policy handler", host)
|
||||
return errors.New("denied by security policy")
|
||||
}
|
||||
}
|
||||
log.Printf("Establishing connection to RDP server: %s", host)
|
||||
@@ -147,7 +143,7 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||
}
|
||||
log.Printf("Connection established")
|
||||
msg := createChannelCreateResponse()
|
||||
h.TransportOut.WritePacket(msg)
|
||||
h.Session.TransportOut.WritePacket(msg)
|
||||
|
||||
// Make sure to start the flow from the RDP server first otherwise connections
|
||||
// might hang eventually
|
||||
@@ -175,8 +171,8 @@ func (h *Handler) Process(ctx context.Context) error {
|
||||
log.Printf("Channel closed while in wrong state %d != %d", h.State, SERVER_STATE_OPENED)
|
||||
return errors.New("wrong state")
|
||||
}
|
||||
h.TransportIn.Close()
|
||||
h.TransportOut.Close()
|
||||
h.Session.TransportIn.Close()
|
||||
h.Session.TransportOut.Close()
|
||||
h.State = SERVER_STATE_CLOSED
|
||||
default:
|
||||
log.Printf("Unknown packet (size %d): %x", sz, pkt)
|
||||
@@ -190,7 +186,7 @@ func (h *Handler) ReadMessage() (pt int, n int, msg []byte, err error) {
|
||||
buf := make([]byte, 4096)
|
||||
|
||||
for {
|
||||
size, pkt, err := h.TransportIn.ReadPacket()
|
||||
size, pkt, err := h.Session.TransportIn.ReadPacket()
|
||||
if err != nil {
|
||||
return 0, 0, []byte{0, 0}, err
|
||||
}
|
||||
@@ -398,7 +394,7 @@ func (h *Handler) sendDataPacket() {
|
||||
break
|
||||
}
|
||||
b1.Write(buf[:n])
|
||||
h.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||
h.Session.TransportOut.WritePacket(createPacket(PKT_TYPE_DATA, b1.Bytes()))
|
||||
b1.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,13 +46,11 @@ type Gateway struct {
|
||||
|
||||
type SessionInfo struct {
|
||||
ConnId string
|
||||
CorrelationId string
|
||||
ClientGeneration string
|
||||
TransportIn transport.Transport
|
||||
TransportOut transport.Transport
|
||||
RemoteAddress string
|
||||
ProxyAddresses string
|
||||
UserName string
|
||||
ProxyAddress string
|
||||
RemoteServer string
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
@@ -65,9 +63,6 @@ func init() {
|
||||
}
|
||||
|
||||
func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
connectionCache.Set(float64(c.ItemCount()))
|
||||
|
||||
var s *SessionInfo
|
||||
@@ -79,6 +74,7 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
||||
} else {
|
||||
s = x.(*SessionInfo)
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), "SessionInfo", s)
|
||||
|
||||
if r.Method == MethodRDGOUT {
|
||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
||||
|
||||
Reference in New Issue
Block a user