remove SetAllowedFormAction and explicitly set csp header

This commit is contained in:
Elias Schneider
2026-04-19 15:30:23 +02:00
parent d620bc6818
commit 881d3df24e
6 changed files with 94 additions and 95 deletions

View File

@@ -4,9 +4,7 @@ package frontend
import (
"bytes"
"context"
"embed"
"errors"
"fmt"
"io"
"io/fs"
@@ -20,7 +18,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/middleware"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
//go:embed all:dist/*
@@ -58,27 +55,6 @@ func init() {
}
}
// validateRedirectURI validates that the redirect_uri is in the client's allowed callback URLs
func validateRedirectURI(ctx any, oidcService *service.OidcService, clientID, redirectURI string) (bool, error) {
client, err := oidcService.GetClient(ctx.(context.Context), clientID)
if err != nil {
return false, err
}
// If the client has no callback URLs configured, reject the redirect URI
if len(client.CallbackURLs) == 0 {
return false, errors.New("client has no callback URLs configured")
}
// Validate the redirect URI against the client's callback URLs
_, err = utils.GetCallbackURLFromList(client.CallbackURLs, redirectURI)
if err != nil {
return false, err
}
return true, nil
}
func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc, oidcService *service.OidcService) error {
distFS, err := fs.Sub(frontendFS, "dist")
if err != nil {
@@ -116,22 +92,20 @@ func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc, o
}
if path == "index.html" {
nonce := middleware.GetCSPNonce(c)
// Check if this is an OAuth2 authorization request with response_mode=form_post
// In that case, we need to validate and allow form submissions to the redirect_uri
responseMode := c.Query("response_mode")
redirectURI := c.Query("redirect_uri")
clientID := c.Query("client_id")
if responseMode == "form_post" && redirectURI != "" && clientID != "" {
// Validate the redirect_uri against the client's allowlist
isValid, err := validateRedirectURI(c.Request.Context(), oidcService, clientID, redirectURI)
if err == nil && isValid {
// Set the allowed form-action in CSP to include the redirect URI
middleware.SetAllowedFormAction(c, redirectURI)
validatedRedirectURI, err := oidcService.ResolveAllowedCallbackURL(c.Request.Context(), clientID, redirectURI)
if err == nil {
c.Header("Content-Security-Policy", middleware.BuildCSP(nonce, validatedRedirectURI))
}
}
nonce := middleware.GetCSPNonce(c)
// Do not cache the HTML shell, as it embeds a per-request nonce
c.Header("Content-Type", "text/html; charset=utf-8")
c.Header("Cache-Control", "no-store")

View File

@@ -100,12 +100,6 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
return
}
// Set the allowed form-action in CSP after validation (when response_mode is form_post)
// Only set if we have a valid callback URL from the service
if input.ResponseMode == "form_post" && callbackURL != "" {
middleware.SetAllowedFormAction(c, callbackURL)
}
response := dto.AuthorizeOidcClientResponseDto{
Code: code,
CallbackURL: callbackURL,

View File

@@ -23,42 +23,35 @@ func GetCSPNonce(c *gin.Context) string {
return ""
}
// SetAllowedFormAction sets the allowed form-action for the CSP header.
// This is used for OAuth2 response_mode=form_post to allow form submissions to the client's redirect URI.
func SetAllowedFormAction(c *gin.Context, uri string) {
c.Set("csp_allowed_form_action", uri)
}
func (m *CspMiddleware) Add() gin.HandlerFunc {
return func(c *gin.Context) {
nonce := generateNonce()
c.Set("csp_nonce", nonce)
c.Writer.Header().Set("Content-Security-Policy", BuildCSP(nonce))
// Let the handler run first, then set CSP header with the final context values
c.Next()
// Determine if there is an EXTRA target beyond 'self'
// This is set by handlers (e.g., OIDC authorize) after validating the redirect URI
var extraAction string
if v, ok := c.Get("csp_allowed_form_action"); ok {
extraAction, _ = v.(string)
}
// 'self' is kept in the string; extraAction is just appended
csp := "default-src 'self'; " +
"base-uri 'self'; " +
"object-src 'none'; " +
"frame-ancestors 'none'; " +
"form-action 'self' " + extraAction + "; " +
"img-src * blob:;" +
"font-src 'self'; " +
"style-src 'self' 'unsafe-inline'; " +
"script-src 'self' 'nonce-" + nonce + "'"
c.Writer.Header().Set("Content-Security-Policy", csp)
}
}
func BuildCSP(nonce string, formActionExtra ...string) string {
formAction := "'self'"
for _, extra := range formActionExtra {
if extra != "" {
formAction += " " + extra
}
}
return "default-src 'self'; " +
"base-uri 'self'; " +
"object-src 'none'; " +
"frame-ancestors 'none'; " +
"form-action " + formAction + "; " +
"img-src * blob:;" +
"font-src 'self'; " +
"style-src 'self' 'unsafe-inline'; " +
"script-src 'self' 'nonce-" + nonce + "'"
}
func generateNonce() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {

View File

@@ -0,0 +1,24 @@
package middleware
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestBuildCSP(t *testing.T) {
t.Run("uses self form action by default", func(t *testing.T) {
csp := BuildCSP("test-nonce")
assert.Contains(t, csp, "form-action 'self';")
assert.Contains(t, csp, "script-src 'self' 'nonce-test-nonce'")
})
t.Run("adds validated form action targets", func(t *testing.T) {
csp := BuildCSP("test-nonce", "https://example.com/callback")
assert.Contains(t, csp, "form-action 'self' https://example.com/callback;")
assert.Equal(t, 1, strings.Count(csp, "form-action"))
})
}

View File

@@ -683,6 +683,27 @@ func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.Oid
return s.getClientInternal(ctx, clientID, s.db, false)
}
func (s *OidcService) ResolveAllowedCallbackURL(ctx context.Context, clientID, inputCallbackURL string) (string, error) {
client, err := s.GetClient(ctx, clientID)
if err != nil {
return "", err
}
if inputCallbackURL == "" || len(client.CallbackURLs) == 0 {
return "", &common.OidcMissingCallbackURLError{}
}
matched, err := utils.GetCallbackURLFromList(client.CallbackURLs, inputCallbackURL)
if err != nil {
return "", err
}
if matched == "" {
return "", &common.OidcInvalidCallbackURLError{}
}
return matched, nil
}
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB, forUpdate bool) (model.OidcClient, error) {
var client model.OidcClient
q := tx.

View File

@@ -1,4 +1,4 @@
import test, { expect } from '@playwright/test';
import test, { expect, type Page, type Request } from '@playwright/test';
import { oidcClients, refreshTokens, users } from '../data';
import { cleanupBackend } from '../utils/cleanup.util';
import { generateIdToken, generateOauthAccessToken } from '../utils/jwt.util';
@@ -625,31 +625,21 @@ test('Forces reauthentication when client requires it', async ({ page, request }
expect(webauthnStartCalled).toBe(true);
});
test('Authorize existing client while not signed in with response_mode=form_post', async ({ page }) => {
test('Authorize existing client while not signed in with response_mode=form_post', async ({
page
}) => {
const oidcClient = oidcClients.nextcloud;
const urlParams = createUrlParams(oidcClient);
urlParams.set('response_mode', 'form_post');
await page.context().clearCookies();
// Track if we receive a POST request to the callback URL
let formPostReceived = false;
page.on('request', (request) => {
if (request.method() === 'POST' && request.url().includes(oidcClient.callbackUrl)) {
formPostReceived = true;
}
});
const formPostRequestPromise = waitForFormPostRequest(page, oidcClient.callbackUrl);
await page.goto(`/authorize?${urlParams.toString()}`);
await (await passkeyUtil.init(page)).addPasskey();
await page.getByRole('button', { name: 'Sign in' }).click();
// Wait for the POST request to be made to the callback URL
await expect(() => {
if (!formPostReceived) {
throw new Error('Form POST request not received');
}
}).toBeTruthy();
await expectFormPostRequest(formPostRequestPromise);
});
test('Authorize existing client with response_mode=form_post', async ({ page }) => {
@@ -657,20 +647,23 @@ test('Authorize existing client with response_mode=form_post', async ({ page })
const urlParams = createUrlParams(oidcClient);
urlParams.set('response_mode', 'form_post');
// Track if we receive a POST request to the callback URL
let formPostReceived = false;
page.on('request', (request) => {
if (request.method() === 'POST' && request.url().includes(oidcClient.callbackUrl)) {
formPostReceived = true;
}
});
const formPostRequestPromise = waitForFormPostRequest(page, oidcClient.callbackUrl);
await page.goto(`/authorize?${urlParams.toString()}`);
// Wait for the POST request to be made to the callback URL
await expect(() => {
if (!formPostReceived) {
throw new Error('Form POST request not received');
}
}).toBeTruthy();
});
await expectFormPostRequest(formPostRequestPromise);
});
function waitForFormPostRequest(page: Page, callbackUrl: string): Promise<Request> {
return page.waitForRequest(
(request) => request.method() === 'POST' && request.url() === callbackUrl
);
}
async function expectFormPostRequest(formPostRequestPromise: Promise<Request>) {
const request = await formPostRequestPromise;
const formData = new URLSearchParams(request.postData() ?? '');
expect(formData.get('code')).toBeTruthy();
expect(formData.get('state')).toBe('nXx-6Qr-owc1SHBa');
expect(formData.get('iss')).toBeTruthy();
}