diff --git a/main.go b/main.go index fd4bd52..5d151ae 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ type Config struct { APIBaseUrl string `json:"apiBaseUrl,omitempty"` UserSessionCookieName string `json:"userSessionCookieName,omitempty"` ResourceSessionRequestParam string `json:"resourceSessionRequestParam,omitempty"` + AccessTokenQueryParam string `json:"accessTokenQueryParam,omitempty"` // Deprecated: use ResourceSessionRequestParam DisableForwardAuth bool `json:"disableForwardAuth,omitempty"` TrustIP []string `json:"trustip,omitempty"` DisableDefaultCFIPs bool `json:"disableDefaultCFIPs,omitempty"` @@ -37,6 +38,7 @@ type Badger struct { apiBaseUrl string userSessionCookieName string resourceSessionRequestParam string + accessTokenQueryParam string disableForwardAuth bool trustIP []*net.IPNet customIPHeader string @@ -95,6 +97,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h apiBaseUrl: config.APIBaseUrl, userSessionCookieName: config.UserSessionCookieName, resourceSessionRequestParam: config.ResourceSessionRequestParam, + accessTokenQueryParam: config.AccessTokenQueryParam, disableForwardAuth: config.DisableForwardAuth, customIPHeader: config.CustomIPHeader, } @@ -314,6 +317,8 @@ func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) { p.stripSessionCookies(req) p.stripSessionParam(req) + req.Header.Del("P-Access-Token-Id") + req.Header.Del("P-Access-Token") fmt.Println("Badger: Valid session") p.next.ServeHTTP(rw, req) @@ -419,39 +424,69 @@ func (p *Badger) getRealIP(req *http.Request) string { func (p *Badger) stripSessionParam(req *http.Request) { query := req.URL.Query() + modified := false if query.Has(p.resourceSessionRequestParam) { query.Del(p.resourceSessionRequestParam) + modified = true + } + if p.accessTokenQueryParam != "" && query.Has(p.accessTokenQueryParam) { + query.Del(p.accessTokenQueryParam) + modified = true + } + if modified { req.URL.RawQuery = query.Encode() + req.RequestURI = req.URL.RequestURI() } } // stripSessionCookies removes session cookies from the request before forwarding to the backend. -// We parse the Cookie header as a string rather than using req.Cookies() + re-encoding so we -// preserve the exact header format; the Cookie request header only carries name=value pairs -// (domain/path are Set-Cookie response attributes and are not sent in requests per RFC 6265). +// Cookie request headers only contain name=value pairs (Set-Cookie attributes like Path/Domain +// are response-only), so we filter parsed request cookies and rebuild the Cookie header. func (p *Badger) stripSessionCookies(req *http.Request) { - cookieHeader := req.Header.Get("Cookie") - if cookieHeader == "" { + cookieHeaders := req.Header.Values("Cookie") + if len(cookieHeaders) == 0 { return } - var remaining []string - for part := range strings.SplitSeq(cookieHeader, ";") { - part = strings.TrimSpace(part) - if part == "" { + var remaining []*http.Cookie + for _, headerValue := range cookieHeaders { + parsedCookies, err := http.ParseCookie(headerValue) + if err != nil { + // Best-effort fallback for malformed Cookie headers. + for _, part := range strings.Split(headerValue, ";") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + name, value, hasValue := strings.Cut(part, "=") + name = strings.TrimSpace(name) + if strings.HasPrefix(name, p.userSessionCookieName) { + continue + } + if hasValue { + remaining = append(remaining, &http.Cookie{ + Name: name, + Value: strings.TrimSpace(value), + }) + } + } continue } - parts := strings.SplitN(part, "=", 2) - name := strings.TrimSpace(parts[0]) - if !strings.HasPrefix(name, p.userSessionCookieName) { - remaining = append(remaining, part) + + for _, cookie := range parsedCookies { + if !strings.HasPrefix(cookie.Name, p.userSessionCookieName) { + remaining = append(remaining, cookie) + } } } - if len(remaining) > 0 { - req.Header.Set("Cookie", strings.Join(remaining, "; ")) - } else { - req.Header.Del("Cookie") + req.Header.Del("Cookie") + if len(remaining) == 0 { + return + } + + for _, cookie := range remaining { + req.AddCookie(cookie) } }