diff --git a/flake.nix b/flake.nix index b388cc9..78d0291 100644 --- a/flake.nix +++ b/flake.nix @@ -25,7 +25,7 @@ inherit (pkgs) lib; # Update version when releasing - version = "1.12.4"; + version = "1.12.5"; in { default = self.packages.${system}.pangolin-newt; diff --git a/netstack2/http_handler.go b/netstack2/http_handler.go index 7ba2f63..ece82e9 100644 --- a/netstack2/http_handler.go +++ b/netstack2/http_handler.go @@ -131,18 +131,23 @@ func (l *chanListener) send(conn net.Conn) bool { // httpConnCtx – conn wrapper that carries a SubnetRule through the listener // --------------------------------------------------------------------------- -// httpConnCtx wraps a net.Conn so the matching SubnetRule can be passed -// through the chanListener into the http.Server's ConnContext callback, -// making it available to request handlers via the request context. +// httpConnCtx wraps a net.Conn so the matching SubnetRule and TLS state can +// be passed through the chanListener into the http.Server's ConnContext +// callback, making them available to request handlers via the request context. type httpConnCtx struct { net.Conn - rule *SubnetRule + rule *SubnetRule + isTLS bool // true when the conn was wrapped with tls.Server } // connCtxKey is the unexported context key used to store a *SubnetRule on the // per-connection context created by http.Server.ConnContext. type connCtxKey struct{} +// connTLSKey is the unexported context key used to store the isTLS flag on +// the per-connection context created by http.Server.ConnContext. +type connTLSKey struct{} + // --------------------------------------------------------------------------- // Constructor and lifecycle // --------------------------------------------------------------------------- @@ -175,7 +180,8 @@ func (h *HTTPHandler) Start() error { // that handleRequest can retrieve it without any global state. ConnContext: func(ctx context.Context, c net.Conn) context.Context { if cc, ok := c.(*httpConnCtx); ok { - return context.WithValue(ctx, connCtxKey{}, cc.rule) + ctx = context.WithValue(ctx, connCtxKey{}, cc.rule) + ctx = context.WithValue(ctx, connTLSKey{}, cc.isTLS) } return ctx }, @@ -203,19 +209,28 @@ func (h *HTTPHandler) HandleConn(conn net.Conn, rule *SubnetRule) { var effectiveConn net.Conn = conn if rule.Protocol == "https" { - tlsCfg, err := h.getTLSConfig(rule) - if err != nil { - logger.Error("HTTP handler: cannot build TLS config for connection from %s: %v", - conn.RemoteAddr(), err) - conn.Close() - return + // Only perform TLS termination for connections arriving on port 443. + // Connections on port 80 are passed through as plain HTTP so that + // handleRequest can issue the HTTP→HTTPS redirect. + doTLS := false + if tcpAddr, ok := conn.LocalAddr().(*net.TCPAddr); ok { + doTLS = tcpAddr.Port == 443 + } + if doTLS { + tlsCfg, err := h.getTLSConfig(rule) + if err != nil { + logger.Error("HTTP handler: cannot build TLS config for connection from %s: %v", + conn.RemoteAddr(), err) + conn.Close() + return + } + // tls.Server wraps the raw conn; the TLS handshake is deferred until + // the first Read, which the http.Server will trigger naturally. + effectiveConn = tls.Server(conn, tlsCfg) } - // tls.Server wraps the raw conn; the TLS handshake is deferred until - // the first Read, which the http.Server will trigger naturally. - effectiveConn = tls.Server(conn, tlsCfg) } - wrapped := &httpConnCtx{Conn: effectiveConn, rule: rule} + wrapped := &httpConnCtx{Conn: effectiveConn, rule: rule, isTLS: effectiveConn != conn} if !h.listener.send(wrapped) { // Listener is already closed — clean up the orphaned connection. effectiveConn.Close() @@ -291,6 +306,9 @@ func (h *HTTPHandler) getProxy(target HTTPTarget) *httputil.ReverseProxy { proxy := &httputil.ReverseProxy{ Rewrite: func(pr *httputil.ProxyRequest) { pr.SetURL(targetURL) + if host := pr.In.Host; host != "" { + pr.Out.Host = host + } // SetXForwarded sets X-Forwarded-For from the inbound request's // RemoteAddr (the WireGuard/netstack client address), along with // X-Forwarded-Host and X-Forwarded-Proto. Using Rewrite instead of @@ -356,9 +374,13 @@ func (h *HTTPHandler) handleRequest(w http.ResponseWriter, r *http.Request) { return } - // If the rule is HTTPS and a TLS certificate is configured, but the - // incoming request arrived over plain HTTP, redirect to HTTPS. - if rule.Protocol == "https" && rule.TLSCert != "" && rule.TLSKey != "" && r.TLS == nil { + // If the rule is HTTPS but the incoming request arrived over plain HTTP + // (port 80), redirect to HTTPS. We use the isTLS flag stored on the + // connection context rather than r.TLS, because Go's http.Server calls + // ConnectionState() before the TLS handshake completes, so r.TLS.Version + // is 0 even for genuine TLS connections at that point. + isTLS, _ := r.Context().Value(connTLSKey{}).(bool) + if rule.Protocol == "https" && !isTLS { host := r.Host if host == "" { host = r.URL.Host diff --git a/netstack2/http_handler_tls_test.go b/netstack2/http_handler_tls_test.go new file mode 100644 index 0000000..0f2ffdc --- /dev/null +++ b/netstack2/http_handler_tls_test.go @@ -0,0 +1,48 @@ +package netstack2 + +import ( + "crypto/tls" + "net" + "testing" +) + +// tlsConnStub is a minimal net.Conn that also exposes TLS state, matching +// *tls.Conn's ConnectionState used by net/http.Server. +type tlsConnStub struct { + net.Conn + state tls.ConnectionState +} + +func (t *tlsConnStub) ConnectionState() tls.ConnectionState { + return t.state +} + +func TestHTTPConnCtxForwardsConnectionState(t *testing.T) { + c1, c2 := net.Pipe() + defer c1.Close() + defer c2.Close() + + inner := &tlsConnStub{ + Conn: c1, + state: tls.ConnectionState{Version: tls.VersionTLS12, HandshakeComplete: true}, + } + wrapped := &httpConnCtx{Conn: inner, rule: nil} + + got := wrapped.ConnectionState() + if got.Version != tls.VersionTLS12 || !got.HandshakeComplete { + t.Fatalf("ConnectionState = %+v, want TLS 1.2 and HandshakeComplete", got) + } +} + +func TestHTTPConnCtxConnectionStatePlainTCP(t *testing.T) { + c1, c2 := net.Pipe() + defer c1.Close() + defer c2.Close() + + wrapped := &httpConnCtx{Conn: c1, rule: nil} + got := wrapped.ConnectionState() + if got.Version != 0 { + t.Fatalf("expected zero ConnectionState for plain conn, got %+v", got) + } + _ = c2 +}