fix: handle other fields and ensure cookie coverage

This commit is contained in:
Laurence
2026-03-03 17:51:16 +00:00
parent 877363a686
commit 2902340f7b

69
main.go
View File

@@ -17,6 +17,7 @@ type Config struct {
APIBaseUrl string `json:"apiBaseUrl,omitempty"` APIBaseUrl string `json:"apiBaseUrl,omitempty"`
UserSessionCookieName string `json:"userSessionCookieName,omitempty"` UserSessionCookieName string `json:"userSessionCookieName,omitempty"`
ResourceSessionRequestParam string `json:"resourceSessionRequestParam,omitempty"` ResourceSessionRequestParam string `json:"resourceSessionRequestParam,omitempty"`
AccessTokenQueryParam string `json:"accessTokenQueryParam,omitempty"` // Deprecated: use ResourceSessionRequestParam
DisableForwardAuth bool `json:"disableForwardAuth,omitempty"` DisableForwardAuth bool `json:"disableForwardAuth,omitempty"`
TrustIP []string `json:"trustip,omitempty"` TrustIP []string `json:"trustip,omitempty"`
DisableDefaultCFIPs bool `json:"disableDefaultCFIPs,omitempty"` DisableDefaultCFIPs bool `json:"disableDefaultCFIPs,omitempty"`
@@ -37,6 +38,7 @@ type Badger struct {
apiBaseUrl string apiBaseUrl string
userSessionCookieName string userSessionCookieName string
resourceSessionRequestParam string resourceSessionRequestParam string
accessTokenQueryParam string
disableForwardAuth bool disableForwardAuth bool
trustIP []*net.IPNet trustIP []*net.IPNet
customIPHeader string customIPHeader string
@@ -95,6 +97,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
apiBaseUrl: config.APIBaseUrl, apiBaseUrl: config.APIBaseUrl,
userSessionCookieName: config.UserSessionCookieName, userSessionCookieName: config.UserSessionCookieName,
resourceSessionRequestParam: config.ResourceSessionRequestParam, resourceSessionRequestParam: config.ResourceSessionRequestParam,
accessTokenQueryParam: config.AccessTokenQueryParam,
disableForwardAuth: config.DisableForwardAuth, disableForwardAuth: config.DisableForwardAuth,
customIPHeader: config.CustomIPHeader, customIPHeader: config.CustomIPHeader,
} }
@@ -314,6 +317,8 @@ func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
p.stripSessionCookies(req) p.stripSessionCookies(req)
p.stripSessionParam(req) p.stripSessionParam(req)
req.Header.Del("P-Access-Token-Id")
req.Header.Del("P-Access-Token")
fmt.Println("Badger: Valid session") fmt.Println("Badger: Valid session")
p.next.ServeHTTP(rw, req) p.next.ServeHTTP(rw, req)
@@ -419,39 +424,69 @@ func (p *Badger) getRealIP(req *http.Request) string {
func (p *Badger) stripSessionParam(req *http.Request) { func (p *Badger) stripSessionParam(req *http.Request) {
query := req.URL.Query() query := req.URL.Query()
modified := false
if query.Has(p.resourceSessionRequestParam) { if query.Has(p.resourceSessionRequestParam) {
query.Del(p.resourceSessionRequestParam) query.Del(p.resourceSessionRequestParam)
modified = true
}
if p.accessTokenQueryParam != "" && query.Has(p.accessTokenQueryParam) {
query.Del(p.accessTokenQueryParam)
modified = true
}
if modified {
req.URL.RawQuery = query.Encode() req.URL.RawQuery = query.Encode()
req.RequestURI = req.URL.RequestURI()
} }
} }
// stripSessionCookies removes session cookies from the request before forwarding to the backend. // stripSessionCookies removes session cookies from the request before forwarding to the backend.
// We parse the Cookie header as a string rather than using req.Cookies() + re-encoding so we // Cookie request headers only contain name=value pairs (Set-Cookie attributes like Path/Domain
// preserve the exact header format; the Cookie request header only carries name=value pairs // are response-only), so we filter parsed request cookies and rebuild the Cookie header.
// (domain/path are Set-Cookie response attributes and are not sent in requests per RFC 6265).
func (p *Badger) stripSessionCookies(req *http.Request) { func (p *Badger) stripSessionCookies(req *http.Request) {
cookieHeader := req.Header.Get("Cookie") cookieHeaders := req.Header.Values("Cookie")
if cookieHeader == "" { if len(cookieHeaders) == 0 {
return return
} }
var remaining []string var remaining []*http.Cookie
for part := range strings.SplitSeq(cookieHeader, ";") { for _, headerValue := range cookieHeaders {
part = strings.TrimSpace(part) parsedCookies, err := http.ParseCookie(headerValue)
if part == "" { if err != nil {
// Best-effort fallback for malformed Cookie headers.
for _, part := range strings.Split(headerValue, ";") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
name, value, hasValue := strings.Cut(part, "=")
name = strings.TrimSpace(name)
if strings.HasPrefix(name, p.userSessionCookieName) {
continue
}
if hasValue {
remaining = append(remaining, &http.Cookie{
Name: name,
Value: strings.TrimSpace(value),
})
}
}
continue continue
} }
parts := strings.SplitN(part, "=", 2)
name := strings.TrimSpace(parts[0]) for _, cookie := range parsedCookies {
if !strings.HasPrefix(name, p.userSessionCookieName) { if !strings.HasPrefix(cookie.Name, p.userSessionCookieName) {
remaining = append(remaining, part) remaining = append(remaining, cookie)
}
} }
} }
if len(remaining) > 0 { req.Header.Del("Cookie")
req.Header.Set("Cookie", strings.Join(remaining, "; ")) if len(remaining) == 0 {
} else { return
req.Header.Del("Cookie") }
for _, cookie := range remaining {
req.AddCookie(cookie)
} }
} }