Fix protocol fallback (#172)

* Fix protocol fallback

* Use token-aware header matching, drop dead fallback log

Connection is a list header; plain equality misses legitimate
"keep-alive, Upgrade" clients. Switch to case-insensitive token
matching for the Connection/Upgrade checks.

Remove the "falling back to old protocol" log on upgrade failure.
upgrader.Upgrade commits an HTTP error response before returning, so
the follow-up legacy path cannot produce a coherent reply. The real
fallback happens at the header pre-check for clients and reverse
proxies that strip the upgrade tokens.

Add tests for the header helper and RDGOUT routing.
This commit is contained in:
Darkar25
2026-04-23 20:35:16 +03:00
committed by GitHub
parent a93c9b438c
commit 4d58a9eb97
2 changed files with 183 additions and 16 deletions

View File

@@ -7,6 +7,7 @@ import (
"net"
"net/http"
"reflect"
"strings"
"syscall"
"time"
@@ -78,29 +79,49 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
ctx = context.WithValue(ctx, CtxTunnel, t)
if r.Method == MethodRDGOUT {
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
g.handleLegacyProtocol(w, r.WithContext(ctx), t)
if headerHasToken(r.Header, "Connection", "upgrade") && headerHasToken(r.Header, "Upgrade", "websocket") {
r.Method = "GET" // upgrader requires GET
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
// Upgrade has already written an HTTP error response on the
// wire, so we cannot transparently fall back to the legacy
// protocol here. The header pre-check above handles the
// real-world fallback case: clients or reverse proxies that
// don't send the Upgrade/Connection tokens route to legacy
// without ever touching the upgrader.
log.Printf("cannot upgrade connection to websocket: %v", err)
return
}
defer conn.Close()
if err := g.setSendReceiveBuffers(conn.UnderlyingConn()); err != nil {
log.Printf("cannot set send/receive buffers: %v", err)
}
g.handleWebsocketProtocol(ctx, conn, t)
return
}
r.Method = "GET" // force
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("Cannot upgrade falling back to old protocol: %t", err)
return
}
defer conn.Close()
err = g.setSendReceiveBuffers(conn.UnderlyingConn())
if err != nil {
log.Printf("Cannot set send/receive buffers: %t", err)
}
g.handleWebsocketProtocol(ctx, conn, t)
g.handleLegacyProtocol(w, r.WithContext(ctx), t)
} else if r.Method == MethodRDGIN {
g.handleLegacyProtocol(w, r.WithContext(ctx), t)
}
}
// headerHasToken reports whether the HTTP header named by name contains
// token, matched case-insensitively against comma-separated tokens.
// Fields like Connection carry a list (e.g. "keep-alive, Upgrade") so a
// plain equality check on the raw value misses legitimate clients and
// well-behaved reverse proxies.
func headerHasToken(h http.Header, name, token string) bool {
for _, v := range h.Values(name) {
for _, t := range strings.Split(v, ",") {
if strings.EqualFold(strings.TrimSpace(t), token) {
return true
}
}
}
return false
}
func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
if g.SendBuf < 1 && g.ReceiveBuf < 1 {
return nil

View File

@@ -0,0 +1,146 @@
package protocol
import (
"bufio"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
)
func TestHeaderHasToken(t *testing.T) {
cases := []struct {
name string
values []string
token string
want bool
}{
{"exact match", []string{"upgrade"}, "upgrade", true},
{"case insensitive value", []string{"Upgrade"}, "upgrade", true},
{"case insensitive token", []string{"upgrade"}, "UPGRADE", true},
{"token inside comma list", []string{"keep-alive, Upgrade"}, "upgrade", true},
{"token with surrounding whitespace", []string{" keep-alive ,\tupgrade\t"}, "upgrade", true},
{"not present", []string{"keep-alive"}, "upgrade", false},
{"empty header", nil, "upgrade", false},
{"substring must not match", []string{"upgrader"}, "upgrade", false},
{"multiple header values", []string{"keep-alive", "Upgrade"}, "upgrade", true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
h := http.Header{}
for _, v := range tc.values {
h.Add("X-Test", v)
}
if got := headerHasToken(h, "X-Test", tc.token); got != tc.want {
t.Errorf("headerHasToken(%v, %q) = %v, want %v", tc.values, tc.token, got, tc.want)
}
})
}
}
// newGatewayTestServer starts an httptest server that injects a minimal
// identity into the request context so HandleGatewayProtocol's downstream
// code can run without panicking.
func newGatewayTestServer(t *testing.T, gw *Gateway) *httptest.Server {
t.Helper()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := identity.NewUser()
id.SetAttribute(identity.AttrRemoteAddr, "127.0.0.1:0")
id.SetAttribute(identity.AttrClientIp, "127.0.0.1")
r = identity.AddToRequestCtx(id, r)
gw.HandleGatewayProtocol(w, r)
})
return httptest.NewServer(handler)
}
// TestHandleGatewayProtocolRouting drives HandleGatewayProtocol over a real
// TCP connection to verify the Connection/Upgrade header detection picks
// the right sub-handler. Using raw TCP is necessary because the legacy
// branch hijacks the connection and the websocket branch speaks 101
// Switching Protocols, neither of which plays nicely with net/http's
// standard client.
func TestHandleGatewayProtocolRouting(t *testing.T) {
srv := newGatewayTestServer(t, &Gateway{})
defer srv.Close()
addr := strings.TrimPrefix(srv.URL, "http://")
cases := []struct {
name string
request string
wantStatusLine string
}{
{
name: "RDG_OUT_DATA without upgrade headers routes to legacy",
request: "RDG_OUT_DATA /remoteDesktopGateway/ HTTP/1.1\r\n" +
"Host: " + addr + "\r\n" +
"Rdg-Connection-Id: test-legacy\r\n" +
"\r\n",
wantStatusLine: "HTTP/1.1 200 OK",
},
{
name: "RDG_OUT_DATA with upgrade headers routes to websocket",
request: "RDG_OUT_DATA /remoteDesktopGateway/ HTTP/1.1\r\n" +
"Host: " + addr + "\r\n" +
"Rdg-Connection-Id: test-ws\r\n" +
"Connection: Upgrade\r\n" +
"Upgrade: websocket\r\n" +
"Sec-WebSocket-Version: 13\r\n" +
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" +
"\r\n",
wantStatusLine: "HTTP/1.1 101 Switching Protocols",
},
{
name: "RDG_OUT_DATA with Connection token list still routes to websocket",
request: "RDG_OUT_DATA /remoteDesktopGateway/ HTTP/1.1\r\n" +
"Host: " + addr + "\r\n" +
"Rdg-Connection-Id: test-ws-list\r\n" +
"Connection: keep-alive, Upgrade\r\n" +
"Upgrade: websocket\r\n" +
"Sec-WebSocket-Version: 13\r\n" +
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" +
"\r\n",
wantStatusLine: "HTTP/1.1 101 Switching Protocols",
},
{
name: "RDG_OUT_DATA with partially matching headers routes to legacy",
request: "RDG_OUT_DATA /remoteDesktopGateway/ HTTP/1.1\r\n" +
"Host: " + addr + "\r\n" +
"Rdg-Connection-Id: test-partial\r\n" +
"Connection: Upgrade\r\n" +
"\r\n",
wantStatusLine: "HTTP/1.1 200 OK",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
if err := conn.SetDeadline(time.Now().Add(3 * time.Second)); err != nil {
t.Fatalf("set deadline: %v", err)
}
if _, err := conn.Write([]byte(tc.request)); err != nil {
t.Fatalf("write: %v", err)
}
line, err := bufio.NewReader(conn).ReadString('\n')
if err != nil {
t.Fatalf("read status line: %v", err)
}
line = strings.TrimRight(line, "\r\n")
if line != tc.wantStatusLine {
t.Errorf("status line = %q, want %q", line, tc.wantStatusLine)
}
})
}
}