mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-05-13 00:19:52 +00:00
remove SetAllowedFormAction and explicitly set csp header
This commit is contained in:
@@ -4,9 +4,7 @@ package frontend
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
@@ -20,7 +18,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/middleware"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
)
|
||||
|
||||
//go:embed all:dist/*
|
||||
@@ -58,27 +55,6 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// validateRedirectURI validates that the redirect_uri is in the client's allowed callback URLs
|
||||
func validateRedirectURI(ctx any, oidcService *service.OidcService, clientID, redirectURI string) (bool, error) {
|
||||
client, err := oidcService.GetClient(ctx.(context.Context), clientID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// If the client has no callback URLs configured, reject the redirect URI
|
||||
if len(client.CallbackURLs) == 0 {
|
||||
return false, errors.New("client has no callback URLs configured")
|
||||
}
|
||||
|
||||
// Validate the redirect URI against the client's callback URLs
|
||||
_, err = utils.GetCallbackURLFromList(client.CallbackURLs, redirectURI)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc, oidcService *service.OidcService) error {
|
||||
distFS, err := fs.Sub(frontendFS, "dist")
|
||||
if err != nil {
|
||||
@@ -116,22 +92,20 @@ func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc, o
|
||||
}
|
||||
|
||||
if path == "index.html" {
|
||||
nonce := middleware.GetCSPNonce(c)
|
||||
|
||||
// Check if this is an OAuth2 authorization request with response_mode=form_post
|
||||
// In that case, we need to validate and allow form submissions to the redirect_uri
|
||||
responseMode := c.Query("response_mode")
|
||||
redirectURI := c.Query("redirect_uri")
|
||||
clientID := c.Query("client_id")
|
||||
if responseMode == "form_post" && redirectURI != "" && clientID != "" {
|
||||
// Validate the redirect_uri against the client's allowlist
|
||||
isValid, err := validateRedirectURI(c.Request.Context(), oidcService, clientID, redirectURI)
|
||||
if err == nil && isValid {
|
||||
// Set the allowed form-action in CSP to include the redirect URI
|
||||
middleware.SetAllowedFormAction(c, redirectURI)
|
||||
validatedRedirectURI, err := oidcService.ResolveAllowedCallbackURL(c.Request.Context(), clientID, redirectURI)
|
||||
if err == nil {
|
||||
c.Header("Content-Security-Policy", middleware.BuildCSP(nonce, validatedRedirectURI))
|
||||
}
|
||||
}
|
||||
|
||||
nonce := middleware.GetCSPNonce(c)
|
||||
|
||||
// Do not cache the HTML shell, as it embeds a per-request nonce
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-store")
|
||||
|
||||
@@ -100,12 +100,6 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Set the allowed form-action in CSP after validation (when response_mode is form_post)
|
||||
// Only set if we have a valid callback URL from the service
|
||||
if input.ResponseMode == "form_post" && callbackURL != "" {
|
||||
middleware.SetAllowedFormAction(c, callbackURL)
|
||||
}
|
||||
|
||||
response := dto.AuthorizeOidcClientResponseDto{
|
||||
Code: code,
|
||||
CallbackURL: callbackURL,
|
||||
|
||||
@@ -23,42 +23,35 @@ func GetCSPNonce(c *gin.Context) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetAllowedFormAction sets the allowed form-action for the CSP header.
|
||||
// This is used for OAuth2 response_mode=form_post to allow form submissions to the client's redirect URI.
|
||||
func SetAllowedFormAction(c *gin.Context, uri string) {
|
||||
c.Set("csp_allowed_form_action", uri)
|
||||
}
|
||||
|
||||
func (m *CspMiddleware) Add() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
nonce := generateNonce()
|
||||
c.Set("csp_nonce", nonce)
|
||||
c.Writer.Header().Set("Content-Security-Policy", BuildCSP(nonce))
|
||||
|
||||
// Let the handler run first, then set CSP header with the final context values
|
||||
c.Next()
|
||||
|
||||
// Determine if there is an EXTRA target beyond 'self'
|
||||
// This is set by handlers (e.g., OIDC authorize) after validating the redirect URI
|
||||
var extraAction string
|
||||
if v, ok := c.Get("csp_allowed_form_action"); ok {
|
||||
extraAction, _ = v.(string)
|
||||
}
|
||||
|
||||
// 'self' is kept in the string; extraAction is just appended
|
||||
csp := "default-src 'self'; " +
|
||||
"base-uri 'self'; " +
|
||||
"object-src 'none'; " +
|
||||
"frame-ancestors 'none'; " +
|
||||
"form-action 'self' " + extraAction + "; " +
|
||||
"img-src * blob:;" +
|
||||
"font-src 'self'; " +
|
||||
"style-src 'self' 'unsafe-inline'; " +
|
||||
"script-src 'self' 'nonce-" + nonce + "'"
|
||||
|
||||
c.Writer.Header().Set("Content-Security-Policy", csp)
|
||||
}
|
||||
}
|
||||
|
||||
func BuildCSP(nonce string, formActionExtra ...string) string {
|
||||
formAction := "'self'"
|
||||
for _, extra := range formActionExtra {
|
||||
if extra != "" {
|
||||
formAction += " " + extra
|
||||
}
|
||||
}
|
||||
|
||||
return "default-src 'self'; " +
|
||||
"base-uri 'self'; " +
|
||||
"object-src 'none'; " +
|
||||
"frame-ancestors 'none'; " +
|
||||
"form-action " + formAction + "; " +
|
||||
"img-src * blob:;" +
|
||||
"font-src 'self'; " +
|
||||
"style-src 'self' 'unsafe-inline'; " +
|
||||
"script-src 'self' 'nonce-" + nonce + "'"
|
||||
}
|
||||
|
||||
func generateNonce() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
|
||||
24
backend/internal/middleware/csp_middleware_test.go
Normal file
24
backend/internal/middleware/csp_middleware_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBuildCSP(t *testing.T) {
|
||||
t.Run("uses self form action by default", func(t *testing.T) {
|
||||
csp := BuildCSP("test-nonce")
|
||||
|
||||
assert.Contains(t, csp, "form-action 'self';")
|
||||
assert.Contains(t, csp, "script-src 'self' 'nonce-test-nonce'")
|
||||
})
|
||||
|
||||
t.Run("adds validated form action targets", func(t *testing.T) {
|
||||
csp := BuildCSP("test-nonce", "https://example.com/callback")
|
||||
|
||||
assert.Contains(t, csp, "form-action 'self' https://example.com/callback;")
|
||||
assert.Equal(t, 1, strings.Count(csp, "form-action"))
|
||||
})
|
||||
}
|
||||
@@ -683,6 +683,27 @@ func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.Oid
|
||||
return s.getClientInternal(ctx, clientID, s.db, false)
|
||||
}
|
||||
|
||||
func (s *OidcService) ResolveAllowedCallbackURL(ctx context.Context, clientID, inputCallbackURL string) (string, error) {
|
||||
client, err := s.GetClient(ctx, clientID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if inputCallbackURL == "" || len(client.CallbackURLs) == 0 {
|
||||
return "", &common.OidcMissingCallbackURLError{}
|
||||
}
|
||||
|
||||
matched, err := utils.GetCallbackURLFromList(client.CallbackURLs, inputCallbackURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if matched == "" {
|
||||
return "", &common.OidcInvalidCallbackURLError{}
|
||||
}
|
||||
|
||||
return matched, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB, forUpdate bool) (model.OidcClient, error) {
|
||||
var client model.OidcClient
|
||||
q := tx.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import test, { expect } from '@playwright/test';
|
||||
import test, { expect, type Page, type Request } from '@playwright/test';
|
||||
import { oidcClients, refreshTokens, users } from '../data';
|
||||
import { cleanupBackend } from '../utils/cleanup.util';
|
||||
import { generateIdToken, generateOauthAccessToken } from '../utils/jwt.util';
|
||||
@@ -625,31 +625,21 @@ test('Forces reauthentication when client requires it', async ({ page, request }
|
||||
expect(webauthnStartCalled).toBe(true);
|
||||
});
|
||||
|
||||
test('Authorize existing client while not signed in with response_mode=form_post', async ({ page }) => {
|
||||
test('Authorize existing client while not signed in with response_mode=form_post', async ({
|
||||
page
|
||||
}) => {
|
||||
const oidcClient = oidcClients.nextcloud;
|
||||
const urlParams = createUrlParams(oidcClient);
|
||||
urlParams.set('response_mode', 'form_post');
|
||||
await page.context().clearCookies();
|
||||
|
||||
// Track if we receive a POST request to the callback URL
|
||||
let formPostReceived = false;
|
||||
page.on('request', (request) => {
|
||||
if (request.method() === 'POST' && request.url().includes(oidcClient.callbackUrl)) {
|
||||
formPostReceived = true;
|
||||
}
|
||||
});
|
||||
|
||||
const formPostRequestPromise = waitForFormPostRequest(page, oidcClient.callbackUrl);
|
||||
await page.goto(`/authorize?${urlParams.toString()}`);
|
||||
|
||||
await (await passkeyUtil.init(page)).addPasskey();
|
||||
await page.getByRole('button', { name: 'Sign in' }).click();
|
||||
|
||||
// Wait for the POST request to be made to the callback URL
|
||||
await expect(() => {
|
||||
if (!formPostReceived) {
|
||||
throw new Error('Form POST request not received');
|
||||
}
|
||||
}).toBeTruthy();
|
||||
await expectFormPostRequest(formPostRequestPromise);
|
||||
});
|
||||
|
||||
test('Authorize existing client with response_mode=form_post', async ({ page }) => {
|
||||
@@ -657,20 +647,23 @@ test('Authorize existing client with response_mode=form_post', async ({ page })
|
||||
const urlParams = createUrlParams(oidcClient);
|
||||
urlParams.set('response_mode', 'form_post');
|
||||
|
||||
// Track if we receive a POST request to the callback URL
|
||||
let formPostReceived = false;
|
||||
page.on('request', (request) => {
|
||||
if (request.method() === 'POST' && request.url().includes(oidcClient.callbackUrl)) {
|
||||
formPostReceived = true;
|
||||
}
|
||||
});
|
||||
|
||||
const formPostRequestPromise = waitForFormPostRequest(page, oidcClient.callbackUrl);
|
||||
await page.goto(`/authorize?${urlParams.toString()}`);
|
||||
|
||||
// Wait for the POST request to be made to the callback URL
|
||||
await expect(() => {
|
||||
if (!formPostReceived) {
|
||||
throw new Error('Form POST request not received');
|
||||
}
|
||||
}).toBeTruthy();
|
||||
});
|
||||
await expectFormPostRequest(formPostRequestPromise);
|
||||
});
|
||||
|
||||
function waitForFormPostRequest(page: Page, callbackUrl: string): Promise<Request> {
|
||||
return page.waitForRequest(
|
||||
(request) => request.method() === 'POST' && request.url() === callbackUrl
|
||||
);
|
||||
}
|
||||
|
||||
async function expectFormPostRequest(formPostRequestPromise: Promise<Request>) {
|
||||
const request = await formPostRequestPromise;
|
||||
const formData = new URLSearchParams(request.postData() ?? '');
|
||||
|
||||
expect(formData.get('code')).toBeTruthy();
|
||||
expect(formData.get('state')).toBe('nXx-6Qr-owc1SHBa');
|
||||
expect(formData.get('iss')).toBeTruthy();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user