diff --git a/backend/frontend/frontend_excluded.go b/backend/frontend/frontend_excluded.go index 2e0e35ef..d7c652f6 100644 --- a/backend/frontend/frontend_excluded.go +++ b/backend/frontend/frontend_excluded.go @@ -2,8 +2,11 @@ package frontend -import "github.com/gin-gonic/gin" +import ( + "github.com/gin-gonic/gin" + "github.com/pocket-id/pocket-id/backend/internal/service" +) -func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc) error { +func RegisterFrontend(router *gin.Engine, oidcService *service.OidcService) error { return ErrFrontendNotIncluded } diff --git a/backend/frontend/frontend_included.go b/backend/frontend/frontend_included.go index 6486671b..c7eb88a6 100644 --- a/backend/frontend/frontend_included.go +++ b/backend/frontend/frontend_included.go @@ -10,13 +10,14 @@ import ( "io/fs" "mime" "net/http" - "os" "path" "strings" "time" "github.com/gin-gonic/gin" "github.com/pocket-id/pocket-id/backend/internal/middleware" + "github.com/pocket-id/pocket-id/backend/internal/service" + "golang.org/x/time/rate" ) //go:embed all:dist/* @@ -54,7 +55,7 @@ func init() { } } -func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc) error { +func RegisterFrontend(router *gin.Engine, oidcService *service.OidcService) error { distFS, err := fs.Sub(frontendFS, "dist") if err != nil { return fmt.Errorf("failed to create sub FS: %w", err) @@ -83,16 +84,19 @@ func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc) e return } - // If path is / or does not exist, serve index.html - if path == "" { - path = "index.html" - } else if _, err := fs.Stat(distFS, path); os.IsNotExist(err) { - path = "index.html" - } - - if path == "index.html" { + if isSPARequest(path, distFS) { nonce := middleware.GetCSPNonce(c) + if isOAuth2AuthorizationPostRequest(c) { + // In that case, we need to validate and allow form submissions to the redirect_uri + redirectURI := c.Query("redirect_uri") + clientID := c.Query("client_id") + validatedRedirectURI, err := oidcService.ResolveAllowedCallbackURL(c.Request.Context(), clientID, redirectURI) + if err == nil { + c.Header("Content-Security-Policy", middleware.BuildCSP(nonce, validatedRedirectURI)) + } + } + // Do not cache the HTML shell, as it embeds a per-request nonce c.Header("Content-Type", "text/html; charset=utf-8") c.Header("Cache-Control", "no-store") @@ -108,11 +112,46 @@ func RegisterFrontend(router *gin.Engine, rateLimitMiddleware gin.HandlerFunc) e fileServer.ServeHTTP(c.Writer, c.Request) } - router.NoRoute(rateLimitMiddleware, handler) + rateLimitMiddleware := middleware.NewRateLimitMiddleware().Add(rate.Every(300*time.Millisecond), 50) + router.NoRoute(rateLimitOnlyForOAuth2AuthorizationPostRequest(rateLimitMiddleware, distFS), handler) return nil } +func rateLimitOnlyForOAuth2AuthorizationPostRequest(rateLimitMiddleware gin.HandlerFunc, distFS fs.FS) gin.HandlerFunc { + return func(c *gin.Context) { + path := strings.TrimPrefix(c.Request.URL.Path, "/") + if isSPARequest(path, distFS) && isOAuth2AuthorizationPostRequest(c) { + rateLimitMiddleware(c) + return + } + + c.Next() + } +} + +// isOAuth2AuthorizationRequest checks if this is an OAuth2 authorization request with response_mode=form_post +// In that case, we need to validate and allow form submissions to the redirect_uri +func isOAuth2AuthorizationPostRequest(c *gin.Context) bool { + responseMode := c.Query("response_mode") + redirectURI := c.Query("redirect_uri") + clientID := c.Query("client_id") + + return responseMode == "form_post" && redirectURI != "" && clientID != "" +} + +func isSPARequest(path string, distFS fs.FS) bool { + if path == "" { + return true + } + + if _, err := fs.Stat(distFS, path); err != nil { + return true + } + + return false +} + // FileServerWithCaching wraps http.FileServer to add caching headers type FileServerWithCaching struct { root http.FileSystem diff --git a/backend/frontend/frontend_included_test.go b/backend/frontend/frontend_included_test.go new file mode 100644 index 00000000..5629ca09 --- /dev/null +++ b/backend/frontend/frontend_included_test.go @@ -0,0 +1,111 @@ +//go:build !exclude_frontend + +package frontend + +import ( + "net/http" + "net/http/httptest" + "testing" + "testing/fstest" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestIsSPARequest(t *testing.T) { + distFS := fstest.MapFS{ + "assets/app.js": &fstest.MapFile{Data: []byte("console.log('test')")}, + } + + t.Run("root path is spa request", func(t *testing.T) { + assert.True(t, isSPARequest("", distFS)) + }) + + t.Run("existing bundled asset is not spa request", func(t *testing.T) { + assert.False(t, isSPARequest("assets/app.js", distFS)) + }) + + t.Run("unknown path is spa request", func(t *testing.T) { + assert.True(t, isSPARequest("authorize", distFS)) + }) +} + +func TestRateLimitOnlyForOAuth2AuthorizationPostRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + distFS := fstest.MapFS{ + "assets/app.js": &fstest.MapFile{Data: []byte("console.log('test')")}, + } + + t.Run("rate limits spa form_post request", func(t *testing.T) { + rateLimited := false + nextCalled := false + middleware := rateLimitOnlyForOAuth2AuthorizationPostRequest(func(c *gin.Context) { + rateLimited = true + c.Abort() + }, distFS) + + router := gin.New() + router.NoRoute( + middleware, + func(c *gin.Context) { + nextCalled = true + }, + ) + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/authorize?response_mode=form_post&client_id=test&redirect_uri=https://example.com/callback", nil) + router.ServeHTTP(recorder, req) + + assert.True(t, rateLimited) + assert.False(t, nextCalled) + }) + + t.Run("does not rate limit page request with no form_post params", func(t *testing.T) { + rateLimited := false + nextCalled := false + middleware := rateLimitOnlyForOAuth2AuthorizationPostRequest(func(c *gin.Context) { + rateLimited = true + c.Abort() + }, distFS) + + router := gin.New() + router.NoRoute( + middleware, + func(c *gin.Context) { + nextCalled = true + }, + ) + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/authorize", nil) + router.ServeHTTP(recorder, req) + + assert.False(t, rateLimited) + assert.True(t, nextCalled) + }) + + t.Run("does not rate limit static asset request with form_post params", func(t *testing.T) { + rateLimited := false + nextCalled := false + middleware := rateLimitOnlyForOAuth2AuthorizationPostRequest(func(c *gin.Context) { + rateLimited = true + c.Abort() + }, distFS) + + router := gin.New() + router.NoRoute( + middleware, + func(c *gin.Context) { + nextCalled = true + }, + ) + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/assets/app.js?response_mode=form_post&client_id=test&redirect_uri=https://example.com/callback", nil) + router.ServeHTTP(recorder, req) + + assert.False(t, rateLimited) + assert.True(t, nextCalled) + }) +} diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go index 4cf7e3ea..30b3521e 100644 --- a/backend/internal/bootstrap/router_bootstrap.go +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -40,7 +40,10 @@ func initRouter(db *gorm.DB, svc *services) (utils.Service, error) { if err != nil { return nil, err } - registerRoutes(r, db, svc) + err = registerRoutes(r, db, svc) + if err != nil { + return nil, err + } serverConfig, err := initServer(r) if err != nil { @@ -70,15 +73,6 @@ func initEngine() (*gin.Engine, error) { configureEngine(r) registerGlobalMiddleware(r) - frontendRateLimitMiddleware := middleware.NewRateLimitMiddleware().Add(rate.Every(100*time.Millisecond), 300) - if err := frontend.RegisterFrontend(r, frontendRateLimitMiddleware); err != nil { - if errors.Is(err, frontend.ErrFrontendNotIncluded) { - slog.Warn("Frontend is not included in the build. Skipping frontend registration.") - return r, nil - } - return nil, fmt.Errorf("failed to register frontend: %w", err) - } - return r, nil } @@ -116,7 +110,16 @@ func registerGlobalMiddleware(r *gin.Engine) { r.Use(middleware.NewErrorHandlerMiddleware().Add()) } -func registerRoutes(r *gin.Engine, db *gorm.DB, svc *services) { +func registerRoutes(r *gin.Engine, db *gorm.DB, svc *services) error { + + err := frontend.RegisterFrontend(r, svc.oidcService) + if errors.Is(err, frontend.ErrFrontendNotIncluded) { + slog.Warn("Frontend is not included in the build. Skipping frontend registration.") + } else if err != nil { + return fmt.Errorf("failed to register frontend: %w", err) + } + + // Initialize middleware for specific routes authMiddleware := middleware.NewAuthMiddleware(svc.apiKeyService, svc.userService, svc.jwtService) fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware() apiRateLimitMiddleware := middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 100) @@ -142,6 +145,8 @@ func registerRoutes(r *gin.Engine, db *gorm.DB, svc *services) { // These are not rate-limited. controller.NewHealthzController(r) + + return nil } func registerTestRoutes(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) { diff --git a/backend/internal/dto/oidc_dto.go b/backend/internal/dto/oidc_dto.go index 4d8187e0..cc4a5b51 100644 --- a/backend/internal/dto/oidc_dto.go +++ b/backend/internal/dto/oidc_dto.go @@ -72,6 +72,7 @@ type AuthorizeOidcClientRequestDto struct { CodeChallengeMethod string `json:"codeChallengeMethod"` ReauthenticationToken string `json:"reauthenticationToken"` Prompt string `json:"prompt"` + ResponseMode string `json:"responseMode" binding:"omitempty,response_mode"` } type AuthorizeOidcClientResponseDto struct { diff --git a/backend/internal/dto/validations.go b/backend/internal/dto/validations.go index 8408df0f..2380775c 100644 --- a/backend/internal/dto/validations.go +++ b/backend/internal/dto/validations.go @@ -47,6 +47,9 @@ func init() { "callback_url_pattern": func(fl validator.FieldLevel) bool { return ValidateCallbackURLPattern(fl.Field().String()) }, + "response_mode": func(fl validator.FieldLevel) bool { + return ValidateResponseMode(fl.Field().String()) + }, } for k, v := range validators { err := engine.RegisterValidation(k, v) @@ -87,3 +90,17 @@ func ValidateCallbackURLPattern(raw string) bool { err := utils.ValidateCallbackURLPattern(raw) return err == nil } + +// ValidateResponseMode validates response_mode parameter +// If responseMode is present, it must be "form_post" or "query" +// Empty responseMode is allowed (field not provided, use default) +func ValidateResponseMode(responseMode string) bool { + switch responseMode { + case "form_post", "query": + return true + case "": + return true + default: + return false + } +} diff --git a/backend/internal/dto/validations_test.go b/backend/internal/dto/validations_test.go index 9ef2af78..5f9d595d 100644 --- a/backend/internal/dto/validations_test.go +++ b/backend/internal/dto/validations_test.go @@ -58,6 +58,25 @@ func TestValidateClientID(t *testing.T) { } } +func TestValidateResponseMode(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"valid form_post", "form_post", true}, + {"valid query", "query", true}, + {"valid empty", "", true}, + {"invalid fragment", "fragment", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, ValidateResponseMode(tt.input)) + }) + } +} + func TestValidateCallbackURL(t *testing.T) { tests := []struct { name string diff --git a/backend/internal/middleware/csp_middleware.go b/backend/internal/middleware/csp_middleware.go index b39b13f1..ae64528e 100644 --- a/backend/internal/middleware/csp_middleware.go +++ b/backend/internal/middleware/csp_middleware.go @@ -3,6 +3,7 @@ package middleware import ( "crypto/rand" "encoding/base64" + "strings" "github.com/gin-gonic/gin" ) @@ -28,22 +29,39 @@ func (m *CspMiddleware) Add() gin.HandlerFunc { // Generate a random base64 nonce for this request nonce := generateNonce() c.Set("csp_nonce", nonce) + c.Writer.Header().Set("Content-Security-Policy", BuildCSP(nonce)) - csp := "default-src 'self'; " + - "base-uri 'self'; " + - "object-src 'none'; " + - "frame-ancestors 'none'; " + - "form-action 'self'; " + - "img-src * blob:;" + - "font-src 'self'; " + - "style-src 'self' 'unsafe-inline'; " + - "script-src 'self' 'nonce-" + nonce + "'" - - c.Writer.Header().Set("Content-Security-Policy", csp) c.Next() } } +func BuildCSP(nonce string, formActionExtra ...string) string { + formAction := "'self'" + + if len(formActionExtra) > 0 { + b := strings.Builder{} + + for _, extra := range formActionExtra { + if extra != "" { + b.WriteByte(' ') + b.WriteString(extra) + } + } + + formAction += b.String() + } + + return "default-src 'self'; " + + "base-uri 'self'; " + + "object-src 'none'; " + + "frame-ancestors 'none'; " + + "form-action " + formAction + "; " + + "img-src * blob:;" + + "font-src 'self'; " + + "style-src 'self' 'unsafe-inline'; " + + "script-src 'self' 'nonce-" + nonce + "'" +} + func generateNonce() string { b := make([]byte, 16) if _, err := rand.Read(b); err != nil { diff --git a/backend/internal/middleware/csp_middleware_test.go b/backend/internal/middleware/csp_middleware_test.go new file mode 100644 index 00000000..67f90464 --- /dev/null +++ b/backend/internal/middleware/csp_middleware_test.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBuildCSP(t *testing.T) { + t.Run("uses self form action by default", func(t *testing.T) { + csp := BuildCSP("test-nonce") + + assert.Contains(t, csp, "form-action 'self';") + assert.Contains(t, csp, "script-src 'self' 'nonce-test-nonce'") + }) + + t.Run("adds validated form action targets", func(t *testing.T) { + csp := BuildCSP("test-nonce", "https://example.com/callback") + + assert.Contains(t, csp, "form-action 'self' https://example.com/callback;") + assert.Equal(t, 1, strings.Count(csp, "form-action")) + }) +} diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 6c09cb8c..53824588 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -754,6 +754,27 @@ func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.Oid return s.getClientInternal(ctx, clientID, s.db, false) } +func (s *OidcService) ResolveAllowedCallbackURL(ctx context.Context, clientID, inputCallbackURL string) (string, error) { + client, err := s.GetClient(ctx, clientID) + if err != nil { + return "", err + } + + if inputCallbackURL == "" || len(client.CallbackURLs) == 0 { + return "", &common.OidcMissingCallbackURLError{} + } + + matched, err := utils.GetCallbackURLFromList(client.CallbackURLs, inputCallbackURL) + if err != nil { + return "", err + } + if matched == "" { + return "", &common.OidcInvalidCallbackURLError{} + } + + return matched, nil +} + func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB, forUpdate bool) (model.OidcClient, error) { var client model.OidcClient q := tx. diff --git a/frontend/src/lib/services/oidc-service.ts b/frontend/src/lib/services/oidc-service.ts index abed878a..34579a53 100644 --- a/frontend/src/lib/services/oidc-service.ts +++ b/frontend/src/lib/services/oidc-service.ts @@ -23,6 +23,7 @@ class OidcService extends APIService { codeChallenge?: string, codeChallengeMethod?: string, reauthenticationToken?: string, + responseMode?: string, prompt?: string ) => { const res = await this.api.post('/oidc/authorize', { @@ -33,6 +34,7 @@ class OidcService extends APIService { codeChallenge, codeChallengeMethod, reauthenticationToken, + responseMode, prompt }); diff --git a/frontend/src/routes/authorize/+page.svelte b/frontend/src/routes/authorize/+page.svelte index e9a99ba2..e35285ae 100644 --- a/frontend/src/routes/authorize/+page.svelte +++ b/frontend/src/routes/authorize/+page.svelte @@ -28,7 +28,8 @@ codeChallenge, codeChallengeMethod, authorizeState, - prompt + prompt, + responseMode } = data; let isLoading = $state(false); @@ -118,6 +119,7 @@ codeChallenge, codeChallengeMethod, reauthToken, + responseMode, prompt ); @@ -164,7 +166,46 @@ success = true; setTimeout(() => { - window.location.href = redirectURL.toString(); + if (responseMode === 'form_post') { + // Create a hidden form and submit it via POST + const form = document.createElement('form'); + form.method = 'POST'; + form.action = callbackURL; + + // Add code parameter + const codeInput = document.createElement('input'); + codeInput.type = 'hidden'; + codeInput.name = 'code'; + codeInput.value = code; + form.appendChild(codeInput); + + // Add state parameter + if (authorizeState) { + const stateInput = document.createElement('input'); + stateInput.type = 'hidden'; + stateInput.name = 'state'; + stateInput.value = authorizeState; + form.appendChild(stateInput); + } + + // Add issuer parameter + const issInput = document.createElement('input'); + issInput.type = 'hidden'; + issInput.name = 'iss'; + issInput.value = issuer; + form.appendChild(issInput); + + document.body.appendChild(form); + form.submit(); + } else { + // Default query parameter redirect (response_mode=query or not specified) + const redirectURL = new URL(callbackURL); + redirectURL.searchParams.append('code', code); + redirectURL.searchParams.append('state', authorizeState); + redirectURL.searchParams.append('iss', issuer); + + window.location.href = redirectURL.toString(); + } }, 1000); } @@ -196,7 +237,7 @@ />
{:else if authorizationRequired} -@@ -212,7 +253,7 @@