mirror of
https://github.com/fosrl/newt.git
synced 2026-05-14 12:19:53 +00:00
Support websocket upgrades in private HTTP proxy
Preserve optional ResponseWriter interfaces through statusCapture so httputil.ReverseProxy can hijack upgraded websocket connections. Add a regression test covering websocket traffic through the HTTP handler path.
This commit is contained in:
@@ -6,8 +6,10 @@
|
|||||||
package netstack2
|
package netstack2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -29,7 +31,7 @@ import (
|
|||||||
type HTTPTarget struct {
|
type HTTPTarget struct {
|
||||||
DestAddr string `json:"destAddr"` // IP address or hostname of the downstream service
|
DestAddr string `json:"destAddr"` // IP address or hostname of the downstream service
|
||||||
DestPort uint16 `json:"destPort"` // TCP port of the downstream service
|
DestPort uint16 `json:"destPort"` // TCP port of the downstream service
|
||||||
Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS
|
Scheme string `json:"scheme"` // When true the outbound leg uses HTTPS
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -322,6 +324,24 @@ func (sc *statusCapture) WriteHeader(code int) {
|
|||||||
sc.ResponseWriter.WriteHeader(code)
|
sc.ResponseWriter.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sc *statusCapture) Unwrap() http.ResponseWriter {
|
||||||
|
return sc.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *statusCapture) Flush() {
|
||||||
|
if flusher, ok := sc.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *statusCapture) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
hijacker, ok := sc.ResponseWriter.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, errors.New("underlying response writer does not support hijacking")
|
||||||
|
}
|
||||||
|
return hijacker.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
// handleRequest is the http.Handler entry point. It retrieves the SubnetRule
|
// handleRequest is the http.Handler entry point. It retrieves the SubnetRule
|
||||||
// attached to the connection by ConnContext, selects the first configured
|
// attached to the connection by ConnContext, selects the first configured
|
||||||
// downstream target, and forwards the request via the cached ReverseProxy.
|
// downstream target, and forwards the request via the cached ReverseProxy.
|
||||||
|
|||||||
97
netstack2/http_handler_test.go
Normal file
97
netstack2/http_handler_test.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package netstack2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHTTPHandlerProxiesWebSocketUpgrade(t *testing.T) {
|
||||||
|
upgrader := websocket.Upgrader{}
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("upgrade failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
messageType, payload, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("read failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := conn.WriteMessage(messageType, append([]byte("echo:"), payload...)); err != nil {
|
||||||
|
t.Errorf("write failed: %v", err)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
backendURL, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse backend URL: %v", err)
|
||||||
|
}
|
||||||
|
backendHost, backendPort, err := net.SplitHostPort(backendURL.Host)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("split backend host: %v", err)
|
||||||
|
}
|
||||||
|
port, err := net.LookupPort("tcp", backendPort)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse backend port: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := NewHTTPHandler(nil, nil)
|
||||||
|
rule := &SubnetRule{
|
||||||
|
Protocol: "http",
|
||||||
|
HTTPTargets: []HTTPTarget{
|
||||||
|
{
|
||||||
|
DestAddr: backendHost,
|
||||||
|
DestPort: uint16(port),
|
||||||
|
Scheme: backendURL.Scheme,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := context.WithValue(r.Context(), connCtxKey{}, rule)
|
||||||
|
handler.handleRequest(w, r.WithContext(ctx))
|
||||||
|
}))
|
||||||
|
defer frontend.Close()
|
||||||
|
|
||||||
|
frontendURL, err := url.Parse(frontend.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse frontend URL: %v", err)
|
||||||
|
}
|
||||||
|
wsURL := url.URL{
|
||||||
|
Scheme: "ws",
|
||||||
|
Host: frontendURL.Host,
|
||||||
|
Path: "/socket",
|
||||||
|
RawQuery: "token=test",
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial websocket through proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if err := conn.WriteMessage(websocket.TextMessage, []byte("hello")); err != nil {
|
||||||
|
t.Fatalf("write websocket message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
messageType, payload, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read websocket message: %v", err)
|
||||||
|
}
|
||||||
|
if messageType != websocket.TextMessage {
|
||||||
|
t.Fatalf("message type = %d, want %d", messageType, websocket.TextMessage)
|
||||||
|
}
|
||||||
|
if got, want := string(payload), "echo:hello"; got != want {
|
||||||
|
t.Fatalf("payload = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user