From b151c323482f275ace5289f45ff401906ea4bb7b Mon Sep 17 00:00:00 2001 From: Milo Schwartz Date: Wed, 27 Nov 2024 00:08:20 -0500 Subject: [PATCH] refactor to support resource session cookies at base domain --- main.go | 86 +++++++++++++-------------------------------------------- 1 file changed, 19 insertions(+), 67 deletions(-) diff --git a/main.go b/main.go index ae09c66..f3e8980 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" ) type Config struct { @@ -15,19 +16,14 @@ type Config struct { ResourceSessionCookieName string `json:"resourceSessionCookieName"` } -type SessionData struct { - Session *string `json:"session"` - ResourceSession *string `json:"resource_session"` -} - type VerifyBody struct { - Sessions SessionData `json:"sessions"` - OriginalRequestURL string `json:"originalRequestURL"` - RequestScheme *string `json:"scheme"` - RequestHost *string `json:"host"` - RequestPath *string `json:"path"` - RequestMethod *string `json:"method"` - TLS bool `json:"tls"` + Sessions map[string]string `json:"sessions"` + OriginalRequestURL string `json:"originalRequestURL"` + RequestScheme *string `json:"scheme"` + RequestHost *string `json:"host"` + RequestPath *string `json:"path"` + RequestMethod *string `json:"method"` + TLS bool `json:"tls"` } type VerifyResponse struct { @@ -62,42 +58,13 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h } func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - fmt.Println("config values are: ", p.apiBaseUrl, p.resourceSessionCookieName, p.sessionQueryParameter, p.userSessionCookieName) - - sess := req.URL.Query().Get(p.sessionQueryParameter) - if sess != "" { - http.SetCookie(rw, &http.Cookie{ - Name: p.resourceSessionCookieName, - Value: sess, - Path: "/", - Domain: req.Host, - }) - - query := req.URL.Query() - query.Del(p.sessionQueryParameter) - req.URL.RawQuery = query.Encode() - } - - fmt.Println("checked for session param") - cookies := p.extractCookies(req) - if sess != "" { - cookies.Session = &sess - } - - fmt.Println("extracted cookies") verifyURL := fmt.Sprintf("%s/badger/verify-session", p.apiBaseUrl) - - fmt.Println("verify url is", verifyURL) - originalRequestURL := fmt.Sprintf("%s://%s%s", p.getScheme(req), req.Host, req.URL.RequestURI()) cookieData := VerifyBody{ - Sessions: SessionData{ - Session: cookies.Session, - ResourceSession: cookies.ResourceSession, - }, + Sessions: cookies, OriginalRequestURL: originalRequestURL, RequestScheme: &req.URL.Scheme, RequestHost: &req.Host, @@ -106,15 +73,13 @@ func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) { TLS: req.TLS != nil, } - fmt.Println("built verify body") - jsonData, err := json.Marshal(cookieData) if err != nil { http.Error(rw, "Internal Server Error", http.StatusInternalServerError) // TODO: redirect to error page return } - fmt.Println("JSON data:", string(jsonData)) + fmt.Println("verify request", string(jsonData)) resp, err := http.Post(verifyURL, "application/json", bytes.NewBuffer(jsonData)) if err != nil { @@ -123,15 +88,11 @@ func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } defer resp.Body.Close() - fmt.Println("response status code:", resp.StatusCode) - if resp.StatusCode != http.StatusOK { http.Error(rw, "Internal Server Error", http.StatusInternalServerError) return } - fmt.Println("de marshalling response") - var result VerifyResponse err = json.NewDecoder(resp.Body).Decode(&result) if err != nil { @@ -139,19 +100,12 @@ func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - fmt.Println("handling response") - - if result.Valid { - p.next.ServeHTTP(rw, req) + if result.Data.RedirectURL != nil && *result.Data.RedirectURL != "" { + http.Redirect(rw, req, *result.Data.RedirectURL, http.StatusFound) return } - if result.RedirectURL != nil && *result.RedirectURL != "" { - http.Redirect(rw, req, *result.RedirectURL, http.StatusFound) - return - } - - if !result.Valid { // only do this if for some reason the API doesn't return a redirect and it's not valid + if !result.Data.Valid { // only do this if for some reason the API doesn't return a redirect and it's not valid http.Error(rw, "Unauthorized", http.StatusUnauthorized) return } @@ -161,15 +115,13 @@ func (p *Badger) ServeHTTP(rw http.ResponseWriter, req *http.Request) { p.next.ServeHTTP(rw, req) } -func (p *Badger) extractCookies(req *http.Request) SessionData { - var cookies SessionData +func (p *Badger) extractCookies(req *http.Request) map[string]string { + cookies := make(map[string]string) - if appSSOSessionCookie, err := req.Cookie(p.userSessionCookieName); err == nil { - cookies.Session = &appSSOSessionCookie.Value - } - - if resourceSessionCookie, err := req.Cookie(p.resourceSessionCookieName); err == nil { - cookies.ResourceSession = &resourceSessionCookie.Value + for _, cookie := range req.Cookies() { + if strings.HasPrefix(cookie.Name, p.userSessionCookieName) || strings.HasPrefix(cookie.Name, p.resourceSessionCookieName) { + cookies[cookie.Name] = cookie.Value + } } return cookies