diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 193a6723..669fb0a7 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -89,12 +89,19 @@ type OidcController struct { // @Router /api/oidc/authorize [post] func (oc *OidcController) authorizeHandler(c *gin.Context) { var input dto.AuthorizeOidcClientRequestDto - if err := c.ShouldBindJSON(&input); err != nil { + err := c.ShouldBindJSON(&input) + if err != nil { _ = c.Error(err) return } - code, callbackURL, err := oc.oidcService.Authorize(c.Request.Context(), input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent()) + code, callbackURL, err := oc.oidcService.Authorize( + c.Request.Context(), + input, + c.GetString("userID"), + c.ClientIP(), + c.Request.UserAgent(), + ) if err != nil { _ = c.Error(err) return diff --git a/backend/internal/dto/oidc_dto.go b/backend/internal/dto/oidc_dto.go index e6a186a6..350734a8 100644 --- a/backend/internal/dto/oidc_dto.go +++ b/backend/internal/dto/oidc_dto.go @@ -33,8 +33,8 @@ type OidcClientWithAllowedGroupsCountDto struct { type OidcClientUpdateDto struct { Name string `json:"name" binding:"required,max=50" unorm:"nfc"` - CallbackURLs []string `json:"callbackURLs" binding:"omitempty,dive,callback_url"` - LogoutCallbackURLs []string `json:"logoutCallbackURLs" binding:"omitempty,dive,callback_url"` + CallbackURLs []string `json:"callbackURLs" binding:"omitempty,dive,callback_url_pattern"` + LogoutCallbackURLs []string `json:"logoutCallbackURLs" binding:"omitempty,dive,callback_url_pattern"` IsPublic bool `json:"isPublic"` PkceEnabled bool `json:"pkceEnabled"` RequiresReauthentication bool `json:"requiresReauthentication"` @@ -66,7 +66,7 @@ type OidcClientFederatedIdentityDto struct { type AuthorizeOidcClientRequestDto struct { ClientID string `json:"clientID" binding:"required"` Scope string `json:"scope" binding:"required"` - CallbackURL string `json:"callbackURL"` + CallbackURL string `json:"callbackURL" binding:"omitempty,callback_url"` Nonce string `json:"nonce"` CodeChallenge string `json:"codeChallenge"` CodeChallengeMethod string `json:"codeChallengeMethod"` diff --git a/backend/internal/dto/validations.go b/backend/internal/dto/validations.go index a40973e0..8408df0f 100644 --- a/backend/internal/dto/validations.go +++ b/backend/internal/dto/validations.go @@ -1,7 +1,9 @@ package dto import ( + "net/url" "regexp" + "strings" "time" "github.com/pocket-id/pocket-id/backend/internal/utils" @@ -19,38 +21,38 @@ var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9]([a-zA-Z0-9_.@-]*[a- var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$") func init() { - v := binding.Validator.Engine().(*validator.Validate) + engine := binding.Validator.Engine().(*validator.Validate) // Maximum allowed value for TTLs const maxTTL = 31 * 24 * time.Hour - if err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool { - return ValidateUsername(fl.Field().String()) - }); err != nil { - panic("Failed to register custom validation for username: " + err.Error()) + validators := map[string]validator.Func{ + "username": func(fl validator.FieldLevel) bool { + return ValidateUsername(fl.Field().String()) + }, + "client_id": func(fl validator.FieldLevel) bool { + return ValidateClientID(fl.Field().String()) + }, + "ttl": func(fl validator.FieldLevel) bool { + ttl, ok := fl.Field().Interface().(utils.JSONDuration) + if !ok { + return false + } + // Allow zero, which means the field wasn't set + return ttl.Duration == 0 || (ttl.Duration > time.Second && ttl.Duration <= maxTTL) + }, + "callback_url": func(fl validator.FieldLevel) bool { + return ValidateCallbackURL(fl.Field().String()) + }, + "callback_url_pattern": func(fl validator.FieldLevel) bool { + return ValidateCallbackURLPattern(fl.Field().String()) + }, } - - if err := v.RegisterValidation("client_id", func(fl validator.FieldLevel) bool { - return ValidateClientID(fl.Field().String()) - }); err != nil { - panic("Failed to register custom validation for client_id: " + err.Error()) - } - - if err := v.RegisterValidation("ttl", func(fl validator.FieldLevel) bool { - ttl, ok := fl.Field().Interface().(utils.JSONDuration) - if !ok { - return false + for k, v := range validators { + err := engine.RegisterValidation(k, v) + if err != nil { + panic("Failed to register custom validation for " + k + ": " + err.Error()) } - // Allow zero, which means the field wasn't set - return ttl.Duration == 0 || (ttl.Duration > time.Second && ttl.Duration <= maxTTL) - }); err != nil { - panic("Failed to register custom validation for ttl: " + err.Error()) - } - - if err := v.RegisterValidation("callback_url", func(fl validator.FieldLevel) bool { - return ValidateCallbackURL(fl.Field().String()) - }); err != nil { - panic("Failed to register custom validation for callback_url: " + err.Error()) } } @@ -64,8 +66,24 @@ func ValidateClientID(clientID string) bool { return validateClientIDRegex.MatchString(clientID) } -// ValidateCallbackURL validates callback URLs with support for wildcards -func ValidateCallbackURL(raw string) bool { +// ValidateCallbackURL validates the input callback URL +func ValidateCallbackURL(str string) bool { + // Ensure the URL is a valid one and that the protocol is not "javascript:" or "data:" + u, err := url.Parse(str) + if err != nil { + return false + } + + switch strings.ToLower(u.Scheme) { + case "javascript", "data": + return false + default: + return true + } +} + +// ValidateCallbackURLPattern validates callback URL patterns, with support for wildcards +func ValidateCallbackURLPattern(raw string) bool { err := utils.ValidateCallbackURLPattern(raw) return err == nil } diff --git a/backend/internal/dto/validations_test.go b/backend/internal/dto/validations_test.go index b91b3292..9ef2af78 100644 --- a/backend/internal/dto/validations_test.go +++ b/backend/internal/dto/validations_test.go @@ -57,3 +57,28 @@ func TestValidateClientID(t *testing.T) { }) } } + +func TestValidateCallbackURL(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"valid https URL", "https://example.com/callback", true}, + {"valid loopback URL", "http://127.0.0.1:49813/callback", true}, + {"empty scheme", "//127.0.0.1:49813/callback", true}, + {"valid custom scheme", "pocketid://callback", true}, + {"invalid malformed URL", "http://[::1", false}, + {"invalid missing scheme separator", "://example.com/callback", false}, + {"rejects javascript scheme", "javascript:alert(1)", false}, + {"rejects mixed case javascript scheme", "JavaScript:alert(1)", false}, + {"rejects data scheme", "data:text/html;base64,PGgxPkhlbGxvPC9oMT4=", false}, + {"rejects mixed case data scheme", "DaTa:text/html;base64,PGgxPkhlbGxvPC9oMT4=", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, ValidateCallbackURL(tt.input)) + }) + } +} diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 9d0b7ff0..4f295a6e 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -125,9 +125,7 @@ func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) { func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) { tx := s.db.Begin() - defer func() { - tx.Rollback() - }() + defer tx.Rollback() var client model.OidcClient err := tx. diff --git a/frontend/src/routes/authorize/+page.svelte b/frontend/src/routes/authorize/+page.svelte index 420581e7..e629d6e8 100644 --- a/frontend/src/routes/authorize/+page.svelte +++ b/frontend/src/routes/authorize/+page.svelte @@ -71,19 +71,16 @@ reauthToken = await webauthnService.reauthenticate(authResponse); } - await oidService - .authorize( - client!.id, - scope, - callbackURL, - nonce, - codeChallenge, - codeChallengeMethod, - reauthToken - ) - .then(async ({ code, callbackURL, issuer }) => { - onSuccess(code, callbackURL, issuer); - }); + const authResult = await oidService.authorize( + client!.id, + scope, + callbackURL, + nonce, + codeChallenge, + codeChallengeMethod, + reauthToken + ); + onSuccess(authResult.code, authResult.callbackURL, authResult.issuer); } catch (e) { errorMessage = getWebauthnErrorMessage(e); isLoading = false; @@ -91,13 +88,17 @@ } function onSuccess(code: string, callbackURL: string, issuer: string) { + const redirectURL = new URL(callbackURL); + if (redirectURL.protocol == 'javascript:' || redirectURL.protocol == 'data:') { + throw new Error('Invalid redirect URL protocol'); + } + + redirectURL.searchParams.append('code', code); + redirectURL.searchParams.append('state', authorizeState); + redirectURL.searchParams.append('iss', issuer); + success = true; setTimeout(() => { - const redirectURL = new URL(callbackURL); - redirectURL.searchParams.append('code', code); - redirectURL.searchParams.append('state', authorizeState); - redirectURL.searchParams.append('iss', issuer); - window.location.href = redirectURL.toString(); }, 1000); }