package inspect

import (
	"bufio"
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/netip"
	"strings"
	"sync"
	"time"

	"github.com/netbirdio/netbird/shared/management/domain"
)

const (
	headerUpgrade  = "Upgrade"
	valueWebSocket = "websocket"
)

// inspectHTTP dispatches decrypted traffic to the matching inspection path:
// inspectH2 when proto is "h2", inspectH1 for every other value.
func (p *Proxy) inspectHTTP(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, proto string) error {
	switch proto {
	case "h2":
		return p.inspectH2(ctx, client, remote, dst, sni, src)
	default:
		return p.inspectH1(ctx, client, remote, dst, sni, src)
	}
}

// inspectH1 runs the HTTP/1.1 inspection loop over client and remote.
//
// Each iteration reads one request from client, checks it (domain fronting,
// rule evaluation, optional ICAP REQMOD), forwards it to remote, reads the
// response, optionally passes it through ICAP RESPMOD and writes it back.
// WebSocket upgrade requests leave the loop and are handed to handleWebSocket.
//
// It returns nil when the client connection is closed or either side asked
// for "Connection: close", ErrBlocked when a request is denied, ctx.Err()
// when ctx is done between requests, and a wrapped error otherwise.
func (p *Proxy) inspectH1(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
	fromClient := bufio.NewReader(client)
	fromRemote := bufio.NewReader(remote)

	for {
		if err := ctx.Err(); err != nil {
			return err
		}

		// Bound the wait for the next request so an idle client cannot hold
		// the connection open indefinitely.
		if err := client.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
			return fmt.Errorf("set idle deadline: %w", err)
		}

		req, err := http.ReadRequest(fromClient)
		switch {
		case err == nil:
			// Got a request; continue below.
		case isClosedErr(err):
			return nil
		default:
			return fmt.Errorf("read HTTP request: %w", err)
		}

		// The idle deadline only covers waiting for a request; lift it for
		// the rest of the exchange.
		if err := client.SetReadDeadline(time.Time{}); err != nil {
			return fmt.Errorf("clear read deadline: %w", err)
		}

		// Prefer the Host header for rule evaluation; sni is the fallback.
		host := hostFromRequest(req, sni)

		// Domain fronting: the Host header disagrees with the TLS SNI.
		if isDomainFronting(req, sni) {
			p.log.Debugf("domain fronting detected: SNI=%s Host=%s", sni.PunycodeString(), host.PunycodeString())
			writeBlockResponse(client, req, host)
			return ErrBlocked
		}

		proto := ProtoHTTP
		if isWebSocketUpgrade(req) {
			proto = ProtoWebSocket
		}

		action := p.evaluateAction(src.IP, host, dst, proto, req.URL.Path)
		if action == ActionBlock {
			p.log.Debugf("block: HTTP %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
			writeBlockResponse(client, req, host)
			return ErrBlocked
		}

		p.log.Tracef("allow: HTTP %s %s (host=%s, action=%s)", req.Method, req.URL.Path, host.PunycodeString(), action)

		// Take a snapshot of the ICAP client under the read lock so the same
		// client is used for REQMOD and RESPMOD of this exchange.
		p.mu.RLock()
		icapClient := p.icap
		p.mu.RUnlock()

		// ICAP REQMOD: fail closed, i.e. block the request when ICAP errors.
		if icapClient != nil {
			modified, err := icapClient.ReqMod(req)
			if err != nil {
				p.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
				writeBlockResponse(client, req, host)
				return fmt.Errorf("ICAP REQMOD: %w", err)
			}
			req = modified
		}

		// Checked again because req may have been replaced by ICAP.
		if isWebSocketUpgrade(req) {
			return p.handleWebSocket(ctx, req, client, fromClient, remote, fromRemote)
		}

		removeHopByHopHeaders(req.Header)

		if err := req.Write(remote); err != nil {
			return fmt.Errorf("forward request: %w", err)
		}

		resp, err := http.ReadResponse(fromRemote, req)
		if err != nil {
			return fmt.Errorf("read HTTP response: %w", err)
		}

		// ICAP RESPMOD: on error, drop the upstream body and block.
		if icapClient != nil {
			modified, err := icapClient.RespMod(req, resp)
			if err != nil {
				p.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
				if closeErr := resp.Body.Close(); closeErr != nil {
					p.log.Debugf("close resp body: %v", closeErr)
				}
				writeBlockResponse(client, req, host)
				return fmt.Errorf("ICAP RESPMOD: %w", err)
			}
			resp = modified
		}

		removeHopByHopHeaders(resp.Header)

		// The body is closed after the write whether or not the write worked.
		writeErr := resp.Write(client)
		if closeErr := resp.Body.Close(); closeErr != nil {
			p.log.Debugf("close resp body: %v", closeErr)
		}
		if writeErr != nil {
			return fmt.Errorf("forward response: %w", writeErr)
		}

		// Either side requesting "Connection: close" ends the loop.
		if req.Close || resp.Close {
			return nil
		}
	}
}

// inspectH2 proxies HTTP/2 traffic through Go's http stack.
//
// An http.Server serves the client connection through a single-connection
// listener, and every request is handled by an h2InspectionHandler that
// forwards it through an http.Transport whose dial hooks return the existing
// remote connection. The original comment states that client and remote are
// already-established TLS connections with h2 negotiated.
//
// It returns ctx.Err() when ctx is done (after closing the server), nil when
// Serve reports http.ErrServerClosed, and Serve's error otherwise.
//
// NOTE(review): Serve keeps calling Accept after the client connection ends,
// and singleConnListener.Accept blocks until Close, so this appears to return
// only once ctx is cancelled. Confirm that callers cancel ctx when the
// connection is finished.
func (p *Proxy) inspectH2(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
	// Both dial hooks hand back the already-established upstream connection
	// instead of dialing a new one.
	useRemote := func(_ context.Context, _, _ string) (net.Conn, error) {
		return remote, nil
	}
	transport := &http.Transport{
		DialContext:       useRemote,
		DialTLSContext:    useRemote,
		ForceAttemptHTTP2: true,
	}

	server := &http.Server{
		Handler: &h2InspectionHandler{
			proxy:     p,
			transport: transport,
			dst:       dst,
			sni:       sni,
			src:       src,
		},
	}

	// http.Server has no direct ServeConn for h2, so the client connection is
	// served through a listener that yields exactly that one connection.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- server.Serve(&singleConnListener{conn: client})
	}()

	select {
	case <-ctx.Done():
		if err := server.Close(); err != nil {
			p.log.Debugf("close h2 server: %v", err)
		}
		return ctx.Err()
	case err := <-serveErr:
		if err == http.ErrServerClosed {
			return nil
		}
		return err
	}
}

// h2InspectionHandler inspects each HTTP/2 request/response pair.
type h2InspectionHandler struct { proxy *Proxy transport http.RoundTripper dst netip.AddrPort sni domain.Domain src SourceInfo } func (h *h2InspectionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { host := hostFromRequest(req, h.sni) if isDomainFronting(req, h.sni) { h.proxy.log.Debugf("domain fronting detected: SNI=%s Host=%s", h.sni.PunycodeString(), host.PunycodeString()) writeBlockPage(w, host) return } action := h.proxy.evaluateAction(h.src.IP, host, h.dst, ProtoH2, req.URL.Path) if action == ActionBlock { h.proxy.log.Debugf("block: H2 %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString()) writeBlockPage(w, host) return } // ICAP REQMOD if h.proxy.icap != nil { modified, err := h.proxy.icap.ReqMod(req) if err != nil { h.proxy.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err) writeBlockPage(w, host) return } req = modified } // Forward to upstream req.URL.Scheme = "https" req.URL.Host = h.sni.PunycodeString() req.RequestURI = "" resp, err := h.transport.RoundTrip(req) if err != nil { h.proxy.log.Debugf("h2 upstream error for %s: %v", host.PunycodeString(), err) http.Error(w, "Bad Gateway", http.StatusBadGateway) return } defer func() { if err := resp.Body.Close(); err != nil { h.proxy.log.Debugf("close h2 resp body: %v", err) } }() // ICAP RESPMOD if h.proxy.icap != nil { modified, err := h.proxy.icap.RespMod(req, resp) if err != nil { h.proxy.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err) writeBlockPage(w, host) return } resp = modified } // Copy response headers and body for k, vals := range resp.Header { for _, v := range vals { w.Header().Add(k, v) } } w.WriteHeader(resp.StatusCode) if _, err := io.Copy(w, resp.Body); err != nil { h.proxy.log.Debugf("h2 response copy error: %v", err) } } // handleWebSocket completes the WebSocket upgrade and relays frames bidirectionally. 
func (p *Proxy) handleWebSocket(ctx context.Context, req *http.Request, client io.ReadWriter, clientReader *bufio.Reader, remote io.ReadWriter, remoteReader *bufio.Reader) error { if err := req.Write(remote); err != nil { return fmt.Errorf("forward WebSocket upgrade: %w", err) } resp, err := http.ReadResponse(remoteReader, req) if err != nil { return fmt.Errorf("read WebSocket upgrade response: %w", err) } if err := resp.Write(client); err != nil { if closeErr := resp.Body.Close(); closeErr != nil { p.log.Debugf("close ws resp body: %v", closeErr) } return fmt.Errorf("forward WebSocket upgrade response: %w", err) } if err := resp.Body.Close(); err != nil { p.log.Debugf("close ws resp body: %v", err) } if resp.StatusCode != http.StatusSwitchingProtocols { return fmt.Errorf("WebSocket upgrade rejected: status %d", resp.StatusCode) } p.log.Tracef("allow: WebSocket upgrade for %s", req.Host) // Relay WebSocket frames bidirectionally. // clientReader/remoteReader may have buffered data. clientConn := mergeReadWriter(clientReader, client) remoteConn := mergeReadWriter(remoteReader, remote) return relayRW(ctx, clientConn, remoteConn) } // hostFromRequest extracts a domain.Domain from the HTTP request Host header, // falling back to the SNI if Host is empty or an IP. func hostFromRequest(req *http.Request, fallback domain.Domain) domain.Domain { host := req.Host if host == "" { return fallback } // Strip port if present if h, _, err := net.SplitHostPort(host); err == nil { host = h } // If it's an IP address, use the SNI fallback if _, err := netip.ParseAddr(host); err == nil { return fallback } d, err := domain.FromString(host) if err != nil { return fallback } return d } // isDomainFronting detects domain fronting: the Host header doesn't match the // SNI used during the TLS handshake. Only meaningful when SNI is non-empty // (i.e., we're in MITM mode and know the original SNI). 
func isDomainFronting(req *http.Request, sni domain.Domain) bool { if sni == "" { return false } host := hostFromRequest(req, "") if host == "" { return false } // Host should match SNI or be a subdomain of SNI if host == sni { return false } // Allow www.example.com when SNI is example.com sniStr := sni.PunycodeString() hostStr := host.PunycodeString() if strings.HasSuffix(hostStr, "."+sniStr) { return false } return true } func isWebSocketUpgrade(req *http.Request) bool { return strings.EqualFold(req.Header.Get(headerUpgrade), valueWebSocket) } // writeBlockPage writes the styled HTML block page to an http.ResponseWriter (H2 path). func writeBlockPage(w http.ResponseWriter, host domain.Domain) { hostname := host.PunycodeString() body := fmt.Sprintf(blockPageHTML, hostname, hostname) w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Cache-Control", "no-store") w.WriteHeader(http.StatusForbidden) io.WriteString(w, body) } func writeBlockResponse(w io.Writer, _ *http.Request, host domain.Domain) { hostname := host.PunycodeString() body := fmt.Sprintf(blockPageHTML, hostname, hostname) resp := &http.Response{ StatusCode: http.StatusForbidden, ProtoMajor: 1, ProtoMinor: 1, Header: make(http.Header), ContentLength: int64(len(body)), Body: io.NopCloser(strings.NewReader(body)), } resp.Header.Set("Content-Type", "text/html; charset=utf-8") resp.Header.Set("Connection", "close") resp.Header.Set("Cache-Control", "no-store") _ = resp.Write(w) } // blockPageHTML is the self-contained HTML block page. // Uses NetBird dark theme with orange accent. Two format args: page title domain, displayed domain. const blockPageHTML = ` Blocked - %s
403 BLOCKED

Access Denied

This connection to %s has been blocked by your organization's network policy.

` // singleConnListener is a net.Listener that yields a single connection. type singleConnListener struct { conn net.Conn once sync.Once ch chan struct{} } func (l *singleConnListener) Accept() (net.Conn, error) { var accepted bool l.once.Do(func() { l.ch = make(chan struct{}) accepted = true }) if accepted { return l.conn, nil } // Block until Close <-l.ch return nil, net.ErrClosed } func (l *singleConnListener) Close() error { l.once.Do(func() { l.ch = make(chan struct{}) }) select { case <-l.ch: default: close(l.ch) } return nil } func (l *singleConnListener) Addr() net.Addr { return l.conn.LocalAddr() } type readWriter struct { io.Reader io.Writer } func mergeReadWriter(r io.Reader, w io.Writer) io.ReadWriter { return &readWriter{Reader: r, Writer: w} } // relayRW copies data bidirectionally between two ReadWriters. func relayRW(ctx context.Context, a, b io.ReadWriter) error { ctx, cancel := context.WithCancel(ctx) defer cancel() errCh := make(chan error, 2) go func() { _, err := io.Copy(b, a) cancel() errCh <- err }() go func() { _, err := io.Copy(a, b) cancel() errCh <- err }() var firstErr error for range 2 { if err := <-errCh; err != nil && firstErr == nil { if !isClosedErr(err) { firstErr = err } } } return firstErr } // hopByHopHeaders are HTTP/1.1 headers that apply to a single connection // and must not be forwarded by a proxy (RFC 7230, Section 6.1). var hopByHopHeaders = []string{ "Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization", "TE", "Trailers", "Transfer-Encoding", "Upgrade", } // removeHopByHopHeaders strips hop-by-hop headers from h. // Also removes headers listed in the Connection header value. func removeHopByHopHeaders(h http.Header) { // First, remove any headers named in the Connection header for _, connHeader := range h["Connection"] { for _, name := range strings.Split(connHeader, ",") { h.Del(strings.TrimSpace(name)) } } for _, name := range hopByHopHeaders { h.Del(name) } }