feat: add support for response_mode=form_post (#1360)

Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
John
2026-04-26 15:11:35 +03:00
committed by GitHub
parent 0ed2c48591
commit 64d4ac7919
14 changed files with 387 additions and 42 deletions

View File

@@ -2,8 +2,11 @@
package frontend
import "github.com/gin-gonic/gin"
import (
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/service"
)
func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc) error {
func RegisterFrontend(router *gin.Engine, oidcService *service.OidcService) error {
return ErrFrontendNotIncluded
}

View File

@@ -10,13 +10,14 @@ import (
"io/fs"
"mime"
"net/http"
"os"
"path"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/middleware"
"github.com/pocket-id/pocket-id/backend/internal/service"
"golang.org/x/time/rate"
)
//go:embed all:dist/*
@@ -54,7 +55,7 @@ func init() {
}
}
func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc) error {
func RegisterFrontend(router *gin.Engine, oidcService *service.OidcService) error {
distFS, err := fs.Sub(frontendFS, "dist")
if err != nil {
return fmt.Errorf("failed to create sub FS: %w", err)
@@ -83,16 +84,19 @@ func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc) e
return
}
// If path is / or does not exist, serve index.html
if path == "" {
path = "index.html"
} else if _, err := fs.Stat(distFS, path); os.IsNotExist(err) {
path = "index.html"
}
if path == "index.html" {
if isSPARequest(path, distFS) {
nonce := middleware.GetCSPNonce(c)
if isOAuth2AuthorizationPostRequest(c) {
// In that case, we need to validate and allow form submissions to the redirect_uri
redirectURI := c.Query("redirect_uri")
clientID := c.Query("client_id")
validatedRedirectURI, err := oidcService.ResolveAllowedCallbackURL(c.Request.Context(), clientID, redirectURI)
if err == nil {
c.Header("Content-Security-Policy", middleware.BuildCSP(nonce, validatedRedirectURI))
}
}
// Do not cache the HTML shell, as it embeds a per-request nonce
c.Header("Content-Type", "text/html; charset=utf-8")
c.Header("Cache-Control", "no-store")
@@ -108,11 +112,46 @@ func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc) e
fileServer.ServeHTTP(c.Writer, c.Request)
}
router.NoRoute(rateLimitMiddleware, handler)
rateLimitMiddleware := middleware.NewRateLimitMiddleware().Add(rate.Every(300*time.Millisecond), 50)
router.NoRoute(rateLimitOnlyForOAuth2AuthorizationPostRequest(rateLimitMiddleware, distFS), handler)
return nil
}
func rateLimitOnlyForOAuth2AuthorizationPostRequest(rateLimitMiddleware gin.HandlerFunc, distFS fs.FS) gin.HandlerFunc {
return func(c *gin.Context) {
path := strings.TrimPrefix(c.Request.URL.Path, "/")
if isSPARequest(path, distFS) && isOAuth2AuthorizationPostRequest(c) {
rateLimitMiddleware(c)
return
}
c.Next()
}
}
// isOAuth2AuthorizationRequest checks if this is an OAuth2 authorization request with response_mode=form_post
// In that case, we need to validate and allow form submissions to the redirect_uri
func isOAuth2AuthorizationPostRequest(c *gin.Context) bool {
responseMode := c.Query("response_mode")
redirectURI := c.Query("redirect_uri")
clientID := c.Query("client_id")
return responseMode == "form_post" && redirectURI != "" && clientID != ""
}
func isSPARequest(path string, distFS fs.FS) bool {
if path == "" {
return true
}
if _, err := fs.Stat(distFS, path); err != nil {
return true
}
return false
}
// FileServerWithCaching wraps http.FileServer to add caching headers
type FileServerWithCaching struct {
root http.FileSystem

View File

@@ -0,0 +1,111 @@
//go:build !exclude_frontend
package frontend
import (
"net/http"
"net/http/httptest"
"testing"
"testing/fstest"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestIsSPARequest(t *testing.T) {
distFS := fstest.MapFS{
"assets/app.js": &fstest.MapFile{Data: []byte("console.log('test')")},
}
t.Run("root path is spa request", func(t *testing.T) {
assert.True(t, isSPARequest("", distFS))
})
t.Run("existing bundled asset is not spa request", func(t *testing.T) {
assert.False(t, isSPARequest("assets/app.js", distFS))
})
t.Run("unknown path is spa request", func(t *testing.T) {
assert.True(t, isSPARequest("authorize", distFS))
})
}
func TestRateLimitOnlyForOAuth2AuthorizationPostRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
distFS := fstest.MapFS{
"assets/app.js": &fstest.MapFile{Data: []byte("console.log('test')")},
}
t.Run("rate limits spa form_post request", func(t *testing.T) {
rateLimited := false
nextCalled := false
middleware := rateLimitOnlyForOAuth2AuthorizationPostRequest(func(c *gin.Context) {
rateLimited = true
c.Abort()
}, distFS)
router := gin.New()
router.NoRoute(
middleware,
func(c *gin.Context) {
nextCalled = true
},
)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/authorize?response_mode=form_post&client_id=test&redirect_uri=https://example.com/callback", nil)
router.ServeHTTP(recorder, req)
assert.True(t, rateLimited)
assert.False(t, nextCalled)
})
t.Run("does not rate limit page request with no form_post params", func(t *testing.T) {
rateLimited := false
nextCalled := false
middleware := rateLimitOnlyForOAuth2AuthorizationPostRequest(func(c *gin.Context) {
rateLimited = true
c.Abort()
}, distFS)
router := gin.New()
router.NoRoute(
middleware,
func(c *gin.Context) {
nextCalled = true
},
)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/authorize", nil)
router.ServeHTTP(recorder, req)
assert.False(t, rateLimited)
assert.True(t, nextCalled)
})
t.Run("does not rate limit static asset request with form_post params", func(t *testing.T) {
rateLimited := false
nextCalled := false
middleware := rateLimitOnlyForOAuth2AuthorizationPostRequest(func(c *gin.Context) {
rateLimited = true
c.Abort()
}, distFS)
router := gin.New()
router.NoRoute(
middleware,
func(c *gin.Context) {
nextCalled = true
},
)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/assets/app.js?response_mode=form_post&client_id=test&redirect_uri=https://example.com/callback", nil)
router.ServeHTTP(recorder, req)
assert.False(t, rateLimited)
assert.True(t, nextCalled)
})
}

View File

@@ -40,7 +40,10 @@ func initRouter(db *gorm.DB, svc *services) (utils.Service, error) {
if err != nil {
return nil, err
}
registerRoutes(r, db, svc)
err = registerRoutes(r, db, svc)
if err != nil {
return nil, err
}
serverConfig, err := initServer(r)
if err != nil {
@@ -70,15 +73,6 @@ func initEngine() (*gin.Engine, error) {
configureEngine(r)
registerGlobalMiddleware(r)
frontendRateLimitMiddleware := middleware.NewRateLimitMiddleware().Add(rate.Every(100*time.Millisecond), 300)
if err := frontend.RegisterFrontend(r, frontendRateLimitMiddleware); err != nil {
if errors.Is(err, frontend.ErrFrontendNotIncluded) {
slog.Warn("Frontend is not included in the build. Skipping frontend registration.")
return r, nil
}
return nil, fmt.Errorf("failed to register frontend: %w", err)
}
return r, nil
}
@@ -116,7 +110,16 @@ func registerGlobalMiddleware(r *gin.Engine) {
r.Use(middleware.NewErrorHandlerMiddleware().Add())
}
func registerRoutes(r *gin.Engine, db *gorm.DB, svc *services) {
func registerRoutes(r *gin.Engine, db *gorm.DB, svc *services) error {
err := frontend.RegisterFrontend(r, svc.oidcService)
if errors.Is(err, frontend.ErrFrontendNotIncluded) {
slog.Warn("Frontend is not included in the build. Skipping frontend registration.")
} else if err != nil {
return fmt.Errorf("failed to register frontend: %w", err)
}
// Initialize middleware for specific routes
authMiddleware := middleware.NewAuthMiddleware(svc.apiKeyService, svc.userService, svc.jwtService)
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()
apiRateLimitMiddleware := middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 100)
@@ -142,6 +145,8 @@ func registerRoutes(r *gin.Engine, db *gorm.DB, svc *services) {
// These are not rate-limited.
controller.NewHealthzController(r)
return nil
}
func registerTestRoutes(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) {

View File

@@ -72,6 +72,7 @@ type AuthorizeOidcClientRequestDto struct {
CodeChallengeMethod string `json:"codeChallengeMethod"`
ReauthenticationToken string `json:"reauthenticationToken"`
Prompt string `json:"prompt"`
ResponseMode string `json:"responseMode" binding:"omitempty,response_mode"`
}
type AuthorizeOidcClientResponseDto struct {

View File

@@ -47,6 +47,9 @@ func init() {
"callback_url_pattern": func(fl validator.FieldLevel) bool {
return ValidateCallbackURLPattern(fl.Field().String())
},
"response_mode": func(fl validator.FieldLevel) bool {
return ValidateResponseMode(fl.Field().String())
},
}
for k, v := range validators {
err := engine.RegisterValidation(k, v)
@@ -87,3 +90,17 @@ func ValidateCallbackURLPattern(raw string) bool {
err := utils.ValidateCallbackURLPattern(raw)
return err == nil
}
// ValidateResponseMode validates response_mode parameter
// If responseMode is present, it must be "form_post" or "query"
// Empty responseMode is allowed (field not provided, use default)
func ValidateResponseMode(responseMode string) bool {
switch responseMode {
case "form_post", "query":
return true
case "":
return true
default:
return false
}
}

View File

@@ -58,6 +58,25 @@ func TestValidateClientID(t *testing.T) {
}
}
func TestValidateResponseMode(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{"valid form_post", "form_post", true},
{"valid query", "query", true},
{"valid empty", "", true},
{"invalid fragment", "fragment", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, ValidateResponseMode(tt.input))
})
}
}
func TestValidateCallbackURL(t *testing.T) {
tests := []struct {
name string

View File

@@ -3,6 +3,7 @@ package middleware
import (
"crypto/rand"
"encoding/base64"
"strings"
"github.com/gin-gonic/gin"
)
@@ -28,22 +29,39 @@ func (m *CspMiddleware) Add() gin.HandlerFunc {
// Generate a random base64 nonce for this request
nonce := generateNonce()
c.Set("csp_nonce", nonce)
c.Writer.Header().Set("Content-Security-Policy", BuildCSP(nonce))
csp := "default-src 'self'; " +
"base-uri 'self'; " +
"object-src 'none'; " +
"frame-ancestors 'none'; " +
"form-action 'self'; " +
"img-src * blob:;" +
"font-src 'self'; " +
"style-src 'self' 'unsafe-inline'; " +
"script-src 'self' 'nonce-" + nonce + "'"
c.Writer.Header().Set("Content-Security-Policy", csp)
c.Next()
}
}
func BuildCSP(nonce string, formActionExtra ...string) string {
formAction := "'self'"
if len(formActionExtra) > 0 {
b := strings.Builder{}
for _, extra := range formActionExtra {
if extra != "" {
b.WriteByte(' ')
b.WriteString(extra)
}
}
formAction += b.String()
}
return "default-src 'self'; " +
"base-uri 'self'; " +
"object-src 'none'; " +
"frame-ancestors 'none'; " +
"form-action " + formAction + "; " +
"img-src * blob:;" +
"font-src 'self'; " +
"style-src 'self' 'unsafe-inline'; " +
"script-src 'self' 'nonce-" + nonce + "'"
}
func generateNonce() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {

View File

@@ -0,0 +1,24 @@
package middleware
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestBuildCSP(t *testing.T) {
t.Run("uses self form action by default", func(t *testing.T) {
csp := BuildCSP("test-nonce")
assert.Contains(t, csp, "form-action 'self';")
assert.Contains(t, csp, "script-src 'self' 'nonce-test-nonce'")
})
t.Run("adds validated form action targets", func(t *testing.T) {
csp := BuildCSP("test-nonce", "https://example.com/callback")
assert.Contains(t, csp, "form-action 'self' https://example.com/callback;")
assert.Equal(t, 1, strings.Count(csp, "form-action"))
})
}

View File

@@ -754,6 +754,27 @@ func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.Oid
return s.getClientInternal(ctx, clientID, s.db, false)
}
func (s *OidcService) ResolveAllowedCallbackURL(ctx context.Context, clientID, inputCallbackURL string) (string, error) {
client, err := s.GetClient(ctx, clientID)
if err != nil {
return "", err
}
if inputCallbackURL == "" || len(client.CallbackURLs) == 0 {
return "", &common.OidcMissingCallbackURLError{}
}
matched, err := utils.GetCallbackURLFromList(client.CallbackURLs, inputCallbackURL)
if err != nil {
return "", err
}
if matched == "" {
return "", &common.OidcInvalidCallbackURLError{}
}
return matched, nil
}
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB, forUpdate bool) (model.OidcClient, error) {
var client model.OidcClient
q := tx.

View File

@@ -23,6 +23,7 @@ class OidcService extends APIService {
codeChallenge?: string,
codeChallengeMethod?: string,
reauthenticationToken?: string,
responseMode?: string,
prompt?: string
) => {
const res = await this.api.post('/oidc/authorize', {
@@ -33,6 +34,7 @@ class OidcService extends APIService {
codeChallenge,
codeChallengeMethod,
reauthenticationToken,
responseMode,
prompt
});

View File

@@ -28,7 +28,8 @@
codeChallenge,
codeChallengeMethod,
authorizeState,
prompt
prompt,
responseMode
} = data;
let isLoading = $state(false);
@@ -118,6 +119,7 @@
codeChallenge,
codeChallengeMethod,
reauthToken,
responseMode,
prompt
);
@@ -164,7 +166,46 @@
success = true;
setTimeout(() => {
window.location.href = redirectURL.toString();
if (responseMode === 'form_post') {
// Create a hidden form and submit it via POST
const form = document.createElement('form');
form.method = 'POST';
form.action = callbackURL;
// Add code parameter
const codeInput = document.createElement('input');
codeInput.type = 'hidden';
codeInput.name = 'code';
codeInput.value = code;
form.appendChild(codeInput);
// Add state parameter
if (authorizeState) {
const stateInput = document.createElement('input');
stateInput.type = 'hidden';
stateInput.name = 'state';
stateInput.value = authorizeState;
form.appendChild(stateInput);
}
// Add issuer parameter
const issInput = document.createElement('input');
issInput.type = 'hidden';
issInput.name = 'iss';
issInput.value = issuer;
form.appendChild(issInput);
document.body.appendChild(form);
form.submit();
} else {
// Default query parameter redirect (response_mode=query or not specified)
const redirectURL = new URL(callbackURL);
redirectURL.searchParams.append('code', code);
redirectURL.searchParams.append('state', authorizeState);
redirectURL.searchParams.append('iss', issuer);
window.location.href = redirectURL.toString();
}
}, 1000);
}
</script>
@@ -196,7 +237,7 @@
/>
</p>
{:else if authorizationRequired}
<div class="w-full max-w-[450px]" transition:slide={{ duration: 300 }}>
<div class="w-full max-w-md" transition:slide={{ duration: 300 }}>
<Card.Root class="mt-6 mb-10">
<Card.Header>
<p class="text-muted-foreground text-start">
@@ -212,7 +253,7 @@
</div>
{/if}
<!-- Flex flow is reversed so the sign in button, which has auto-focus, is the first one in the DOM, for a11y -->
<div class="flex w-full max-w-[450px] flex-row-reverse gap-2">
<div class="flex w-full max-w-md flex-row-reverse gap-2">
{#if !errorMessage}
<Button class="flex-1" {isLoading} onclick={authorize} autofocus={true}>
{m.sign_in()}

View File

@@ -15,6 +15,7 @@ export const load: PageLoad = async ({ url }) => {
client,
codeChallenge: url.searchParams.get('code_challenge')!,
codeChallengeMethod: url.searchParams.get('code_challenge_method')!,
prompt: url.searchParams.get('prompt') || undefined
prompt: url.searchParams.get('prompt') || undefined,
responseMode: url.searchParams.get('response_mode') || undefined
};
};

View File

@@ -1,4 +1,4 @@
import test, { expect } from '@playwright/test';
import test, { expect, type Page, type Request } from '@playwright/test';
import { oidcClients, refreshTokens, users } from '../data';
import { cleanupBackend } from '../utils/cleanup.util';
import { generateIdToken, generateOauthAccessToken } from '../utils/jwt.util';
@@ -625,6 +625,49 @@ test('Forces reauthentication when client requires it', async ({ page, request }
expect(webauthnStartCalled).toBe(true);
});
test('Authorize existing client while not signed in with response_mode=form_post', async ({
page
}) => {
const oidcClient = oidcClients.nextcloud;
const urlParams = createUrlParams(oidcClient);
urlParams.set('response_mode', 'form_post');
await page.context().clearCookies();
const formPostRequestPromise = waitForFormPostRequest(page, oidcClient.callbackUrl);
await page.goto(`/authorize?${urlParams.toString()}`);
await (await passkeyUtil.init(page)).addPasskey();
await page.getByRole('button', { name: 'Sign in' }).click();
await expectFormPostRequest(formPostRequestPromise);
});
test('Authorize existing client with response_mode=form_post', async ({ page }) => {
const oidcClient = oidcClients.nextcloud;
const urlParams = createUrlParams(oidcClient);
urlParams.set('response_mode', 'form_post');
const formPostRequestPromise = waitForFormPostRequest(page, oidcClient.callbackUrl);
await page.goto(`/authorize?${urlParams.toString()}`);
await expectFormPostRequest(formPostRequestPromise);
});
function waitForFormPostRequest(page: Page, callbackUrl: string): Promise<Request> {
return page.waitForRequest(
(request) => request.method() === 'POST' && request.url() === callbackUrl
);
}
async function expectFormPostRequest(formPostRequestPromise: Promise<Request>) {
const request = await formPostRequestPromise;
const formData = new URLSearchParams(request.postData() ?? '');
expect(formData.get('code')).toBeTruthy();
expect(formData.get('state')).toBe('nXx-6Qr-owc1SHBa');
expect(formData.get('iss')).toBeTruthy();
}
test.describe('OIDC prompt parameter', () => {
test('prompt=none redirects with login_required when user not authenticated', async ({
page
@@ -783,4 +826,4 @@ test.describe('OIDC prompt parameter', () => {
expect(redirectUrl.searchParams.get('error')).toBe('interaction_required');
expect(redirectUrl.searchParams.get('state')).toBe('nXx-6Qr-owc1SHBa');
});
});
});