diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index 13f0029a4..c298b7f79 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -23,8 +23,14 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { // headers that we wish to use to gather that information on the request. sourceIp := extractSourceIP(r) + // Create a mutable struct to capture data from downstream handlers. + // We pass a pointer in the context - the pointer itself flows down immutably, + // but the struct it points to can be mutated by inner handlers. + capturedData := &proxy.CapturedData{} + ctx := proxy.WithCapturedData(r.Context(), capturedData) + start := time.Now() - next.ServeHTTP(sw, r) + next.ServeHTTP(sw, r.WithContext(ctx)) duration := time.Since(start) host, _, err := net.SplitHostPort(r.Host) @@ -35,8 +41,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { entry := logEntry{ ID: xid.New().String(), - ServiceId: proxy.ServiceIdFromContext(r.Context()), - AccountID: proxy.AccountIdFromContext(r.Context()), + ServiceId: capturedData.GetServiceId(), + AccountID: capturedData.GetAccountId(), Host: host, Path: r.URL.Path, DurationMs: duration.Milliseconds(), diff --git a/proxy/internal/proxy/context.go b/proxy/internal/proxy/context.go index 2aa4d699b..b437a9610 100644 --- a/proxy/internal/proxy/context.go +++ b/proxy/internal/proxy/context.go @@ -28,6 +28,13 @@ func (c *CapturedData) SetServiceId(serviceId string) { c.ServiceId = serviceId } +// GetServiceId safely gets the service ID +func (c *CapturedData) GetServiceId() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.ServiceId +} + // SetAccountId safely sets the account ID func (c *CapturedData) SetAccountId(accountId string) { c.mu.Lock() @@ -35,6 +42,13 @@ func (c *CapturedData) SetAccountId(accountId string) { c.AccountId = accountId } +// GetAccountId safely gets the account ID +func (c *CapturedData) GetAccountId() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.AccountId +} + // WithCapturedData adds a CapturedData struct to the context func WithCapturedData(ctx context.Context, data *CapturedData) context.Context { return context.WithValue(ctx, capturedDataKey, data) diff --git a/proxy/internal/proxy/reverseproxy.go b/proxy/internal/proxy/reverseproxy.go index a06699653..e90850961 100644 --- a/proxy/internal/proxy/reverseproxy.go +++ b/proxy/internal/proxy/reverseproxy.go @@ -39,6 +39,14 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Set the accountId in the context for later retrieval. ctx = withAccountId(ctx, accountID) + // Also populate captured data if it exists (allows middleware to read after handler completes). + // This solves the problem of passing data UP the middleware chain: we put a mutable struct + // pointer in the context, and mutate the struct here so outer middleware can read it. + if capturedData := CapturedDataFromContext(ctx); capturedData != nil { + capturedData.SetServiceId(serviceId) + capturedData.SetAccountId(accountID) + } + // Set up a reverse proxy using the transport and then use it to serve the request. proxy := httputil.NewSingleHostReverseProxy(target) proxy.Transport = p.transport diff --git a/proxy/server.go b/proxy/server.go index ae95546c2..0c0b7d2d3 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -170,7 +170,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { // Finally, start the reverse proxy. s.https = &http.Server{ Addr: addr, - Handler: s.auth.Protect(accessLog.Middleware(s.proxy)), + Handler: accessLog.Middleware(s.auth.Protect(s.proxy)), TLSConfig: tlsConfig, } return s.https.ListenAndServeTLS("", "")