diff --git a/cmd/rdpgw/protocol/gateway.go b/cmd/rdpgw/protocol/gateway.go index 676387f..c1f6869 100644 --- a/cmd/rdpgw/protocol/gateway.go +++ b/cmd/rdpgw/protocol/gateway.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "reflect" + "strings" "syscall" "time" @@ -78,29 +79,49 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) ctx = context.WithValue(ctx, CtxTunnel, t) if r.Method == MethodRDGOUT { - if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { - g.handleLegacyProtocol(w, r.WithContext(ctx), t) + if headerHasToken(r.Header, "Connection", "upgrade") && headerHasToken(r.Header, "Upgrade", "websocket") { + r.Method = "GET" // upgrader requires GET + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + // Upgrade has already written an HTTP error response on the + // wire, so we cannot transparently fall back to the legacy + // protocol here. The header pre-check above handles the + // real-world fallback case: clients or reverse proxies that + // don't send the Upgrade/Connection tokens route to legacy + // without ever touching the upgrader. + log.Printf("cannot upgrade connection to websocket: %v", err) + return + } + defer conn.Close() + + if err := g.setSendReceiveBuffers(conn.UnderlyingConn()); err != nil { + log.Printf("cannot set send/receive buffers: %v", err) + } + g.handleWebsocketProtocol(ctx, conn, t) return } - r.Method = "GET" // force - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Printf("Cannot upgrade falling back to old protocol: %t", err) - return - } - defer conn.Close() - - err = g.setSendReceiveBuffers(conn.UnderlyingConn()) - if err != nil { - log.Printf("Cannot set send/receive buffers: %t", err) - } - - g.handleWebsocketProtocol(ctx, conn, t) + g.handleLegacyProtocol(w, r.WithContext(ctx), t) } else if r.Method == MethodRDGIN { g.handleLegacyProtocol(w, r.WithContext(ctx), t) } } +// headerHasToken reports whether the HTTP header named by name contains +// token, matched case-insensitively against comma-separated tokens. +// Fields like Connection carry a list (e.g. "keep-alive, Upgrade") so a +// plain equality check on the raw value misses legitimate clients and +// well-behaved reverse proxies. +func headerHasToken(h http.Header, name, token string) bool { + for _, v := range h.Values(name) { + for _, t := range strings.Split(v, ",") { + if strings.EqualFold(strings.TrimSpace(t), token) { + return true + } + } + } + return false +} + func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error { if g.SendBuf < 1 && g.ReceiveBuf < 1 { return nil diff --git a/cmd/rdpgw/protocol/gateway_test.go b/cmd/rdpgw/protocol/gateway_test.go new file mode 100644 index 0000000..fd288de --- /dev/null +++ b/cmd/rdpgw/protocol/gateway_test.go @@ -0,0 +1,146 @@ +package protocol + +import ( + "bufio" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" +) + +func TestHeaderHasToken(t *testing.T) { + cases := []struct { + name string + values []string + token string + want bool + }{ + {"exact match", []string{"upgrade"}, "upgrade", true}, + {"case insensitive value", []string{"Upgrade"}, "upgrade", true}, + {"case insensitive token", []string{"upgrade"}, "UPGRADE", true}, + {"token inside comma list", []string{"keep-alive, Upgrade"}, "upgrade", true}, + {"token with surrounding whitespace", []string{" keep-alive ,\tupgrade\t"}, "upgrade", true}, + {"not present", []string{"keep-alive"}, "upgrade", false}, + {"empty header", nil, "upgrade", false}, + {"substring must not match", []string{"upgrader"}, "upgrade", false}, + {"multiple header values", []string{"keep-alive", "Upgrade"}, "upgrade", true}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + h := http.Header{} + for _, v := range tc.values { + h.Add("X-Test", v) + } + if got := headerHasToken(h, "X-Test", tc.token); got != tc.want { + t.Errorf("headerHasToken(%v, %q) = %v, want %v", tc.values, tc.token, got, tc.want) + } + }) + } +} + +// newGatewayTestServer starts an httptest server that injects a minimal +// identity into the request context so HandleGatewayProtocol's downstream +// code can run without panicking. +func newGatewayTestServer(t *testing.T, gw *Gateway) *httptest.Server { + t.Helper() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id := identity.NewUser() + id.SetAttribute(identity.AttrRemoteAddr, "127.0.0.1:0") + id.SetAttribute(identity.AttrClientIp, "127.0.0.1") + r = identity.AddToRequestCtx(id, r) + gw.HandleGatewayProtocol(w, r) + }) + return httptest.NewServer(handler) +} + +// TestHandleGatewayProtocolRouting drives HandleGatewayProtocol over a real +// TCP connection to verify the Connection/Upgrade header detection picks +// the right sub-handler. Using raw TCP is necessary because the legacy +// branch hijacks the connection and the websocket branch speaks 101 +// Switching Protocols, neither of which plays nicely with net/http's +// standard client. +func TestHandleGatewayProtocolRouting(t *testing.T) { + srv := newGatewayTestServer(t, &Gateway{}) + defer srv.Close() + + addr := strings.TrimPrefix(srv.URL, "http://") + + cases := []struct { + name string + request string + wantStatusLine string + }{ + { + name: "RDG_OUT_DATA without upgrade headers routes to legacy", + request: "RDG_OUT_DATA /remoteDesktopGateway/ HTTP/1.1\r\n" + + "Host: " + addr + "\r\n" + + "Rdg-Connection-Id: test-legacy\r\n" + + "\r\n", + wantStatusLine: "HTTP/1.1 200 OK", + }, + { + name: "RDG_OUT_DATA with upgrade headers routes to websocket", + request: "RDG_OUT_DATA /remoteDesktopGateway/ HTTP/1.1\r\n" + + "Host: " + addr + "\r\n" + + "Rdg-Connection-Id: test-ws\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: websocket\r\n" + + "Sec-WebSocket-Version: 13\r\n" + + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + + "\r\n", + wantStatusLine: "HTTP/1.1 101 Switching Protocols", + }, + { + name: "RDG_OUT_DATA with Connection token list still routes to websocket", + request: "RDG_OUT_DATA /remoteDesktopGateway/ HTTP/1.1\r\n" + + "Host: " + addr + "\r\n" + + "Rdg-Connection-Id: test-ws-list\r\n" + + "Connection: keep-alive, Upgrade\r\n" + + "Upgrade: websocket\r\n" + + "Sec-WebSocket-Version: 13\r\n" + + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + + "\r\n", + wantStatusLine: "HTTP/1.1 101 Switching Protocols", + }, + { + name: "RDG_OUT_DATA with partially matching headers routes to legacy", + request: "RDG_OUT_DATA /remoteDesktopGateway/ HTTP/1.1\r\n" + + "Host: " + addr + "\r\n" + + "Rdg-Connection-Id: test-partial\r\n" + + "Connection: Upgrade\r\n" + + "\r\n", + wantStatusLine: "HTTP/1.1 200 OK", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + conn, err := net.DialTimeout("tcp", addr, 2*time.Second) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer conn.Close() + + if err := conn.SetDeadline(time.Now().Add(3 * time.Second)); err != nil { + t.Fatalf("set deadline: %v", err) + } + if _, err := conn.Write([]byte(tc.request)); err != nil { + t.Fatalf("write: %v", err) + } + + line, err := bufio.NewReader(conn).ReadString('\n') + if err != nil { + t.Fatalf("read status line: %v", err) + } + line = strings.TrimRight(line, "\r\n") + if line != tc.wantStatusLine { + t.Errorf("status line = %q, want %q", line, tc.wantStatusLine) + } + }) + } +}