From 881d3df24eb0cb7650496a32f82cf578a6d236fe Mon Sep 17 00:00:00 2001 From: Elias Schneider Date: Sun, 19 Apr 2026 15:30:23 +0200 Subject: [PATCH] remove `SetAllowedFormAction` and explicitly set csp header --- backend/frontend/frontend_included.go | 36 ++---------- .../internal/controller/oidc_controller.go | 6 -- backend/internal/middleware/csp_middleware.go | 47 +++++++--------- .../middleware/csp_middleware_test.go | 24 ++++++++ backend/internal/service/oidc_service.go | 21 +++++++ tests/specs/oidc.spec.ts | 55 ++++++++----------- 6 files changed, 94 insertions(+), 95 deletions(-) create mode 100644 backend/internal/middleware/csp_middleware_test.go diff --git a/backend/frontend/frontend_included.go b/backend/frontend/frontend_included.go index 840e8fa6..f2258257 100644 --- a/backend/frontend/frontend_included.go +++ b/backend/frontend/frontend_included.go @@ -4,9 +4,7 @@ package frontend import ( "bytes" - "context" "embed" - "errors" "fmt" "io" "io/fs" @@ -20,7 +18,6 @@ import ( "github.com/gin-gonic/gin" "github.com/pocket-id/pocket-id/backend/internal/middleware" "github.com/pocket-id/pocket-id/backend/internal/service" - "github.com/pocket-id/pocket-id/backend/internal/utils" ) //go:embed all:dist/* @@ -58,27 +55,6 @@ func init() { } } -// validateRedirectURI validates that the redirect_uri is in the client's allowed callback URLs -func validateRedirectURI(ctx any, oidcService *service.OidcService, clientID, redirectURI string) (bool, error) { - client, err := oidcService.GetClient(ctx.(context.Context), clientID) - if err != nil { - return false, err - } - - // If the client has no callback URLs configured, reject the redirect URI - if len(client.CallbackURLs) == 0 { - return false, errors.New("client has no callback URLs configured") - } - - // Validate the redirect URI against the client's callback URLs - _, err = utils.GetCallbackURLFromList(client.CallbackURLs, redirectURI) - if err != nil { - return false, err - } - - return true, nil -} - func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc, oidcService *service.OidcService) error { distFS, err := fs.Sub(frontendFS, "dist") if err != nil { @@ -116,22 +92,20 @@ func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc, o } if path == "index.html" { + nonce := middleware.GetCSPNonce(c) + // Check if this is an OAuth2 authorization request with response_mode=form_post // In that case, we need to validate and allow form submissions to the redirect_uri responseMode := c.Query("response_mode") redirectURI := c.Query("redirect_uri") clientID := c.Query("client_id") if responseMode == "form_post" && redirectURI != "" && clientID != "" { - // Validate the redirect_uri against the client's allowlist - isValid, err := validateRedirectURI(c.Request.Context(), oidcService, clientID, redirectURI) - if err == nil && isValid { - // Set the allowed form-action in CSP to include the redirect URI - middleware.SetAllowedFormAction(c, redirectURI) + validatedRedirectURI, err := oidcService.ResolveAllowedCallbackURL(c.Request.Context(), clientID, redirectURI) + if err == nil { + c.Header("Content-Security-Policy", middleware.BuildCSP(nonce, validatedRedirectURI)) } } - nonce := middleware.GetCSPNonce(c) - // Do not cache the HTML shell, as it embeds a per-request nonce c.Header("Content-Type", "text/html; charset=utf-8") c.Header("Cache-Control", "no-store") diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 51a9681b..193a6723 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -100,12 +100,6 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) { return } - // Set the allowed form-action in CSP after validation (when response_mode is form_post) - // Only set if we have a valid callback URL from the service - if input.ResponseMode == "form_post" && callbackURL != "" { - middleware.SetAllowedFormAction(c, callbackURL) - } - response := dto.AuthorizeOidcClientResponseDto{ Code: code, CallbackURL: callbackURL, diff --git a/backend/internal/middleware/csp_middleware.go b/backend/internal/middleware/csp_middleware.go index 92d8220a..2fb0633d 100644 --- a/backend/internal/middleware/csp_middleware.go +++ b/backend/internal/middleware/csp_middleware.go @@ -23,42 +23,35 @@ func GetCSPNonce(c *gin.Context) string { return "" } -// SetAllowedFormAction sets the allowed form-action for the CSP header. -// This is used for OAuth2 response_mode=form_post to allow form submissions to the client's redirect URI. -func SetAllowedFormAction(c *gin.Context, uri string) { - c.Set("csp_allowed_form_action", uri) -} - func (m *CspMiddleware) Add() gin.HandlerFunc { return func(c *gin.Context) { nonce := generateNonce() c.Set("csp_nonce", nonce) + c.Writer.Header().Set("Content-Security-Policy", BuildCSP(nonce)) - // Let the handler run first, then set CSP header with the final context values c.Next() - - // Determine if there is an EXTRA target beyond 'self' - // This is set by handlers (e.g., OIDC authorize) after validating the redirect URI - var extraAction string - if v, ok := c.Get("csp_allowed_form_action"); ok { - extraAction, _ = v.(string) - } - - // 'self' is kept in the string; extraAction is just appended - csp := "default-src 'self'; " + - "base-uri 'self'; " + - "object-src 'none'; " + - "frame-ancestors 'none'; " + - "form-action 'self' " + extraAction + "; " + - "img-src * blob:;" + - "font-src 'self'; " + - "style-src 'self' 'unsafe-inline'; " + - "script-src 'self' 'nonce-" + nonce + "'" - - c.Writer.Header().Set("Content-Security-Policy", csp) } } +func BuildCSP(nonce string, formActionExtra ...string) string { + formAction := "'self'" + for _, extra := range formActionExtra { + if extra != "" { + formAction += " " + extra + } + } + + return "default-src 'self'; " + + "base-uri 'self'; " + + "object-src 'none'; " + + "frame-ancestors 'none'; " + + "form-action " + formAction + "; " + + "img-src * blob:;" + + "font-src 'self'; " + + "style-src 'self' 'unsafe-inline'; " + + "script-src 'self' 'nonce-" + nonce + "'" +} + func generateNonce() string { b := make([]byte, 16) if _, err := rand.Read(b); err != nil { diff --git a/backend/internal/middleware/csp_middleware_test.go b/backend/internal/middleware/csp_middleware_test.go new file mode 100644 index 00000000..67f90464 --- /dev/null +++ b/backend/internal/middleware/csp_middleware_test.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBuildCSP(t *testing.T) { + t.Run("uses self form action by default", func(t *testing.T) { + csp := BuildCSP("test-nonce") + + assert.Contains(t, csp, "form-action 'self';") + assert.Contains(t, csp, "script-src 'self' 'nonce-test-nonce'") + }) + + t.Run("adds validated form action targets", func(t *testing.T) { + csp := BuildCSP("test-nonce", "https://example.com/callback") + + assert.Contains(t, csp, "form-action 'self' https://example.com/callback;") + assert.Equal(t, 1, strings.Count(csp, "form-action")) + }) +} diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 9d0b7ff0..0743ab7c 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -683,6 +683,27 @@ func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.Oid return s.getClientInternal(ctx, clientID, s.db, false) } +func (s *OidcService) ResolveAllowedCallbackURL(ctx context.Context, clientID, inputCallbackURL string) (string, error) { + client, err := s.GetClient(ctx, clientID) + if err != nil { + return "", err + } + + if inputCallbackURL == "" || len(client.CallbackURLs) == 0 { + return "", &common.OidcMissingCallbackURLError{} + } + + matched, err := utils.GetCallbackURLFromList(client.CallbackURLs, inputCallbackURL) + if err != nil { + return "", err + } + if matched == "" { + return "", &common.OidcInvalidCallbackURLError{} + } + + return matched, nil +} + func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB, forUpdate bool) (model.OidcClient, error) { var client model.OidcClient q := tx. diff --git a/tests/specs/oidc.spec.ts b/tests/specs/oidc.spec.ts index 7bdcef33..4819d132 100644 --- a/tests/specs/oidc.spec.ts +++ b/tests/specs/oidc.spec.ts @@ -1,4 +1,4 @@ -import test, { expect } from '@playwright/test'; +import test, { expect, type Page, type Request } from '@playwright/test'; import { oidcClients, refreshTokens, users } from '../data'; import { cleanupBackend } from '../utils/cleanup.util'; import { generateIdToken, generateOauthAccessToken } from '../utils/jwt.util'; @@ -625,31 +625,21 @@ test('Forces reauthentication when client requires it', async ({ page, request } expect(webauthnStartCalled).toBe(true); }); -test('Authorize existing client while not signed in with response_mode=form_post', async ({ page }) => { +test('Authorize existing client while not signed in with response_mode=form_post', async ({ + page +}) => { const oidcClient = oidcClients.nextcloud; const urlParams = createUrlParams(oidcClient); urlParams.set('response_mode', 'form_post'); await page.context().clearCookies(); - // Track if we receive a POST request to the callback URL - let formPostReceived = false; - page.on('request', (request) => { - if (request.method() === 'POST' && request.url().includes(oidcClient.callbackUrl)) { - formPostReceived = true; - } - }); - + const formPostRequestPromise = waitForFormPostRequest(page, oidcClient.callbackUrl); await page.goto(`/authorize?${urlParams.toString()}`); await (await passkeyUtil.init(page)).addPasskey(); await page.getByRole('button', { name: 'Sign in' }).click(); - // Wait for the POST request to be made to the callback URL - await expect(() => { - if (!formPostReceived) { - throw new Error('Form POST request not received'); - } - }).toBeTruthy(); + await expectFormPostRequest(formPostRequestPromise); }); test('Authorize existing client with response_mode=form_post', async ({ page }) => { @@ -657,20 +647,23 @@ test('Authorize existing client with response_mode=form_post', async ({ page }) const urlParams = createUrlParams(oidcClient); urlParams.set('response_mode', 'form_post'); - // Track if we receive a POST request to the callback URL - let formPostReceived = false; - page.on('request', (request) => { - if (request.method() === 'POST' && request.url().includes(oidcClient.callbackUrl)) { - formPostReceived = true; - } - }); - + const formPostRequestPromise = waitForFormPostRequest(page, oidcClient.callbackUrl); await page.goto(`/authorize?${urlParams.toString()}`); - // Wait for the POST request to be made to the callback URL - await expect(() => { - if (!formPostReceived) { - throw new Error('Form POST request not received'); - } - }).toBeTruthy(); -}); \ No newline at end of file + await expectFormPostRequest(formPostRequestPromise); +}); + +function waitForFormPostRequest(page: Page, callbackUrl: string): Promise { + return page.waitForRequest( + (request) => request.method() === 'POST' && request.url() === callbackUrl + ); +} + +async function expectFormPostRequest(formPostRequestPromise: Promise) { + const request = await formPostRequestPromise; + const formData = new URLSearchParams(request.postData() ?? ''); + + expect(formData.get('code')).toBeTruthy(); + expect(formData.get('state')).toBe('nXx-6Qr-owc1SHBa'); + expect(formData.get('iss')).toBeTruthy(); +}