mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-05-19 14:50:02 +00:00
Fix protocol fallback (#172)
* Fix protocol fallback * Use token-aware header matching, drop dead fallback log Connection is a list header; plain equality misses legitimate "keep-alive, Upgrade" clients. Switch to case-insensitive token matching for the Connection/Upgrade checks. Remove the "falling back to old protocol" log on upgrade failure. upgrader.Upgrade commits an HTTP error response before returning, so the follow-up legacy path cannot produce a coherent reply. The real fallback happens at the header pre-check for clients and reverse proxies that strip the upgrade tokens. Add tests for the header helper and RDGOUT routing.
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -78,29 +79,49 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request)
|
|||||||
ctx = context.WithValue(ctx, CtxTunnel, t)
|
ctx = context.WithValue(ctx, CtxTunnel, t)
|
||||||
|
|
||||||
if r.Method == MethodRDGOUT {
|
if r.Method == MethodRDGOUT {
|
||||||
if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" {
|
if headerHasToken(r.Header, "Connection", "upgrade") && headerHasToken(r.Header, "Upgrade", "websocket") {
|
||||||
g.handleLegacyProtocol(w, r.WithContext(ctx), t)
|
r.Method = "GET" // upgrader requires GET
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
// Upgrade has already written an HTTP error response on the
|
||||||
|
// wire, so we cannot transparently fall back to the legacy
|
||||||
|
// protocol here. The header pre-check above handles the
|
||||||
|
// real-world fallback case: clients or reverse proxies that
|
||||||
|
// don't send the Upgrade/Connection tokens route to legacy
|
||||||
|
// without ever touching the upgrader.
|
||||||
|
log.Printf("cannot upgrade connection to websocket: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if err := g.setSendReceiveBuffers(conn.UnderlyingConn()); err != nil {
|
||||||
|
log.Printf("cannot set send/receive buffers: %v", err)
|
||||||
|
}
|
||||||
|
g.handleWebsocketProtocol(ctx, conn, t)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.Method = "GET" // force
|
g.handleLegacyProtocol(w, r.WithContext(ctx), t)
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Cannot upgrade falling back to old protocol: %t", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
err = g.setSendReceiveBuffers(conn.UnderlyingConn())
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Cannot set send/receive buffers: %t", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
g.handleWebsocketProtocol(ctx, conn, t)
|
|
||||||
} else if r.Method == MethodRDGIN {
|
} else if r.Method == MethodRDGIN {
|
||||||
g.handleLegacyProtocol(w, r.WithContext(ctx), t)
|
g.handleLegacyProtocol(w, r.WithContext(ctx), t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// headerHasToken reports whether the HTTP header named by name contains
|
||||||
|
// token, matched case-insensitively against comma-separated tokens.
|
||||||
|
// Fields like Connection carry a list (e.g. "keep-alive, Upgrade") so a
|
||||||
|
// plain equality check on the raw value misses legitimate clients and
|
||||||
|
// well-behaved reverse proxies.
|
||||||
|
func headerHasToken(h http.Header, name, token string) bool {
|
||||||
|
for _, v := range h.Values(name) {
|
||||||
|
for _, t := range strings.Split(v, ",") {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(t), token) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
|
func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
|
||||||
if g.SendBuf < 1 && g.ReceiveBuf < 1 {
|
if g.SendBuf < 1 && g.ReceiveBuf < 1 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
146
cmd/rdpgw/protocol/gateway_test.go
Normal file
146
cmd/rdpgw/protocol/gateway_test.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
package protocol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHeaderHasToken(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
values []string
|
||||||
|
token string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"exact match", []string{"upgrade"}, "upgrade", true},
|
||||||
|
{"case insensitive value", []string{"Upgrade"}, "upgrade", true},
|
||||||
|
{"case insensitive token", []string{"upgrade"}, "UPGRADE", true},
|
||||||
|
{"token inside comma list", []string{"keep-alive, Upgrade"}, "upgrade", true},
|
||||||
|
{"token with surrounding whitespace", []string{" keep-alive ,\tupgrade\t"}, "upgrade", true},
|
||||||
|
{"not present", []string{"keep-alive"}, "upgrade", false},
|
||||||
|
{"empty header", nil, "upgrade", false},
|
||||||
|
{"substring must not match", []string{"upgrader"}, "upgrade", false},
|
||||||
|
{"multiple header values", []string{"keep-alive", "Upgrade"}, "upgrade", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
h := http.Header{}
|
||||||
|
for _, v := range tc.values {
|
||||||
|
h.Add("X-Test", v)
|
||||||
|
}
|
||||||
|
if got := headerHasToken(h, "X-Test", tc.token); got != tc.want {
|
||||||
|
t.Errorf("headerHasToken(%v, %q) = %v, want %v", tc.values, tc.token, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newGatewayTestServer starts an httptest server that injects a minimal
|
||||||
|
// identity into the request context so HandleGatewayProtocol's downstream
|
||||||
|
// code can run without panicking.
|
||||||
|
func newGatewayTestServer(t *testing.T, gw *Gateway) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
id := identity.NewUser()
|
||||||
|
id.SetAttribute(identity.AttrRemoteAddr, "127.0.0.1:0")
|
||||||
|
id.SetAttribute(identity.AttrClientIp, "127.0.0.1")
|
||||||
|
r = identity.AddToRequestCtx(id, r)
|
||||||
|
gw.HandleGatewayProtocol(w, r)
|
||||||
|
})
|
||||||
|
return httptest.NewServer(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleGatewayProtocolRouting drives HandleGatewayProtocol over a real
|
||||||
|
// TCP connection to verify the Connection/Upgrade header detection picks
|
||||||
|
// the right sub-handler. Using raw TCP is necessary because the legacy
|
||||||
|
// branch hijacks the connection and the websocket branch speaks 101
|
||||||
|
// Switching Protocols, neither of which plays nicely with net/http's
|
||||||
|
// standard client.
|
||||||
|
func TestHandleGatewayProtocolRouting(t *testing.T) {
|
||||||
|
srv := newGatewayTestServer(t, &Gateway{})
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
addr := strings.TrimPrefix(srv.URL, "http://")
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
request string
|
||||||
|
wantStatusLine string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "RDG_OUT_DATA without upgrade headers routes to legacy",
|
||||||
|
request: "RDG_OUT_DATA /remoteDesktopGateway/ HTTP/1.1\r\n" +
|
||||||
|
"Host: " + addr + "\r\n" +
|
||||||
|
"Rdg-Connection-Id: test-legacy\r\n" +
|
||||||
|
"\r\n",
|
||||||
|
wantStatusLine: "HTTP/1.1 200 OK",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RDG_OUT_DATA with upgrade headers routes to websocket",
|
||||||
|
request: "RDG_OUT_DATA /remoteDesktopGateway/ HTTP/1.1\r\n" +
|
||||||
|
"Host: " + addr + "\r\n" +
|
||||||
|
"Rdg-Connection-Id: test-ws\r\n" +
|
||||||
|
"Connection: Upgrade\r\n" +
|
||||||
|
"Upgrade: websocket\r\n" +
|
||||||
|
"Sec-WebSocket-Version: 13\r\n" +
|
||||||
|
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" +
|
||||||
|
"\r\n",
|
||||||
|
wantStatusLine: "HTTP/1.1 101 Switching Protocols",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RDG_OUT_DATA with Connection token list still routes to websocket",
|
||||||
|
request: "RDG_OUT_DATA /remoteDesktopGateway/ HTTP/1.1\r\n" +
|
||||||
|
"Host: " + addr + "\r\n" +
|
||||||
|
"Rdg-Connection-Id: test-ws-list\r\n" +
|
||||||
|
"Connection: keep-alive, Upgrade\r\n" +
|
||||||
|
"Upgrade: websocket\r\n" +
|
||||||
|
"Sec-WebSocket-Version: 13\r\n" +
|
||||||
|
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" +
|
||||||
|
"\r\n",
|
||||||
|
wantStatusLine: "HTTP/1.1 101 Switching Protocols",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RDG_OUT_DATA with partially matching headers routes to legacy",
|
||||||
|
request: "RDG_OUT_DATA /remoteDesktopGateway/ HTTP/1.1\r\n" +
|
||||||
|
"Host: " + addr + "\r\n" +
|
||||||
|
"Rdg-Connection-Id: test-partial\r\n" +
|
||||||
|
"Connection: Upgrade\r\n" +
|
||||||
|
"\r\n",
|
||||||
|
wantStatusLine: "HTTP/1.1 200 OK",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if err := conn.SetDeadline(time.Now().Add(3 * time.Second)); err != nil {
|
||||||
|
t.Fatalf("set deadline: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := conn.Write([]byte(tc.request)); err != nil {
|
||||||
|
t.Fatalf("write: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
line, err := bufio.NewReader(conn).ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read status line: %v", err)
|
||||||
|
}
|
||||||
|
line = strings.TrimRight(line, "\r\n")
|
||||||
|
if line != tc.wantStatusLine {
|
||||||
|
t.Errorf("status line = %q, want %q", line, tc.wantStatusLine)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user