diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 00f41aa50..c1d385375 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -123,7 +123,7 @@ func runServer(cmd *cobra.Command, args []string) error { _ = util.InitLogger(logger, level, util.LogConsole) - log.Infof("configured log level: %s", level) + logger.Infof("configured log level: %s", level) switch forwardedProto { case "auto", "http", "https": @@ -171,7 +171,7 @@ func runServer(cmd *cobra.Command, args []string) error { defer stop() if err := srv.ListenAndServe(ctx, addr); err != nil { - log.Error(err) + logger.Error(err) return err } return nil diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index b6cae365b..8a966faa3 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -90,10 +90,8 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { if err != nil { host = r.Host } - mw.domainsMux.RLock() - config, exists := mw.domains[host] - mw.domainsMux.RUnlock() + config, exists := mw.getDomainConfig(host) mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists) // Domains that are not configured here or have no authentication schemes applied should simply pass through. @@ -103,115 +101,160 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler { } // Set account and service IDs in captured data for access logging. - if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetAccountId(types.AccountID(config.AccountID)) - cd.SetServiceId(config.ServiceID) - } + setCapturedIDs(r, config) - // Check for error from OAuth callback (e.g., access denied) - if errCode := r.URL.Query().Get("error"); errCode != "" { - var requestID string - if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetOrigin(proxy.OriginAuth) - cd.SetAuthMethod(auth.MethodOIDC.String()) - requestID = cd.GetRequestID() - } - errDesc := r.URL.Query().Get("error_description") - if errDesc == "" { - errDesc = "An error occurred during authentication" - } - web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID) + if mw.handleOAuthCallbackError(w, r) { return } - // Check for an existing session cookie (contains JWT) - if cookie, err := r.Cookie(auth.SessionCookieName); err == nil { - if userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey); err == nil { - if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetUserID(userID) - cd.SetAuthMethod(method) - } - next.ServeHTTP(w, r) - return - } + if mw.forwardWithSessionCookie(w, r, host, config, next) { + return } - // Try to authenticate with each scheme. - methods := make(map[string]string) - var attemptedMethod string - for _, scheme := range config.Schemes { - token, promptData, err := scheme.Authenticate(r) - if err != nil { - mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err) - if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetOrigin(proxy.OriginAuth) - } - http.Error(w, "authentication service unavailable", http.StatusBadGateway) - return + mw.authenticateWithSchemes(w, r, host, config) + }) +} + +func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) { + mw.domainsMux.RLock() + defer mw.domainsMux.RUnlock() + config, exists := mw.domains[host] + return config, exists +} + +func setCapturedIDs(r *http.Request, config DomainConfig) { + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetAccountId(types.AccountID(config.AccountID)) + cd.SetServiceId(config.ServiceID) + } +} + +// handleOAuthCallbackError checks for error query parameters from an OAuth +// callback and renders the access denied page if present. +func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool { + errCode := r.URL.Query().Get("error") + if errCode == "" { + return false + } + + var requestID string + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + cd.SetAuthMethod(auth.MethodOIDC.String()) + requestID = cd.GetRequestID() + } + errDesc := r.URL.Query().Get("error_description") + if errDesc == "" { + errDesc = "An error occurred during authentication" + } + web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID) + return true +} + +// forwardWithSessionCookie checks for a valid session cookie and, if found, +// sets the user identity on the request context and forwards to the next handler. +func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool { + cookie, err := r.Cookie(auth.SessionCookieName) + if err != nil { + return false + } + userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey) + if err != nil { + return false + } + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetUserID(userID) + cd.SetAuthMethod(method) + } + next.ServeHTTP(w, r) + return true +} + +// authenticateWithSchemes tries each configured auth scheme in order. +// On success it sets a session cookie and redirects; on failure it renders the login page. +func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) { + methods := make(map[string]string) + var attemptedMethod string + + for _, scheme := range config.Schemes { + token, promptData, err := scheme.Authenticate(r) + if err != nil { + mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err) + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) } - - // Track if credentials were submitted but auth failed - if token == "" && wasCredentialSubmitted(r, scheme.Type()) { - attemptedMethod = scheme.Type().String() - } - - if token != "" { - result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type()) - if err != nil { - if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetOrigin(proxy.OriginAuth) - cd.SetAuthMethod(scheme.Type().String()) - } - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if !result.Valid { - var requestID string - if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetOrigin(proxy.OriginAuth) - cd.SetUserID(result.UserID) - cd.SetAuthMethod(scheme.Type().String()) - requestID = cd.GetRequestID() - } - web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID) - return - } - - expiration := config.SessionExpiration - if expiration == 0 { - expiration = auth.DefaultSessionExpiry - } - http.SetCookie(w, &http.Cookie{ - Name: auth.SessionCookieName, - Value: token, - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteLaxMode, - MaxAge: int(expiration.Seconds()), - }) - - // Redirect instead of forwarding the auth POST to the backend. - // The browser will follow with a GET carrying the new session cookie. - if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { - cd.SetOrigin(proxy.OriginAuth) - cd.SetUserID(result.UserID) - cd.SetAuthMethod(scheme.Type().String()) - } - redirectURL := stripSessionTokenParam(r.URL) - http.Redirect(w, r, redirectURL, http.StatusSeeOther) - return - } - methods[scheme.Type().String()] = promptData + http.Error(w, "authentication service unavailable", http.StatusBadGateway) + return } + // Track if credentials were submitted but auth failed + if token == "" && wasCredentialSubmitted(r, scheme.Type()) { + attemptedMethod = scheme.Type().String() + } + + if token != "" { + mw.handleAuthenticatedToken(w, r, host, token, config, scheme) + return + } + methods[scheme.Type().String()] = promptData + } + + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + if attemptedMethod != "" { + cd.SetAuthMethod(attemptedMethod) + } + } + web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized) +} + +// handleAuthenticatedToken validates the token, handles denied access, and on +// success sets a session cookie and redirects to the original URL. +func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Request, host, token string, config DomainConfig, scheme Scheme) { + result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type()) + if err != nil { if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { cd.SetOrigin(proxy.OriginAuth) - if attemptedMethod != "" { - cd.SetAuthMethod(attemptedMethod) - } + cd.SetAuthMethod(scheme.Type().String()) } - web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if !result.Valid { + var requestID string + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + cd.SetUserID(result.UserID) + cd.SetAuthMethod(scheme.Type().String()) + requestID = cd.GetRequestID() + } + web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID) + return + } + + expiration := config.SessionExpiration + if expiration == 0 { + expiration = auth.DefaultSessionExpiry + } + http.SetCookie(w, &http.Cookie{ + Name: auth.SessionCookieName, + Value: token, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: int(expiration.Seconds()), }) + + // Redirect instead of forwarding the auth POST to the backend. + // The browser will follow with a GET carrying the new session cookie. + if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil { + cd.SetOrigin(proxy.OriginAuth) + cd.SetUserID(result.UserID) + cd.SetAuthMethod(scheme.Type().String()) + } + redirectURL := stripSessionTokenParam(r.URL) + http.Redirect(w, r, redirectURL, http.StatusSeeOther) } // wasCredentialSubmitted checks if credentials were submitted for the given auth method.