diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index fbce6f81..4eb4f99d 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -1260,7 +1260,10 @@ func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID stri return "", err } - codeChallengeMethodSha256 := strings.ToUpper(codeChallengeMethod) == "S256" + codeChallengeMethodSha256, err := codeChallengeMethodIsSha256(codeChallengeMethod) + if err != nil { + return "", err + } oidcAuthorizationCode := model.OidcAuthorizationCode{ ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)), @@ -1285,6 +1288,19 @@ func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID stri return randomString, nil } +func codeChallengeMethodIsSha256(codeChallengeMethod string) (bool, error) { + switch strings.ToUpper(codeChallengeMethod) { + case "": + return false, nil + case "PLAIN": + return false, nil + case "S256": + return true, nil + default: + return false, common.NewOidcInvalidRequestError("code challenge method not supported") + } +} + func validateCodeVerifier(codeVerifier, codeChallenge string, codeChallengeMethodSha256 bool) bool { if codeVerifier == "" || codeChallenge == "" { return false diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go index bd5e663d..d5d8ad45 100644 --- a/backend/internal/service/oidc_service_test.go +++ b/backend/internal/service/oidc_service_test.go @@ -763,6 +763,61 @@ func TestValidateCodeVerifier_Plain(t *testing.T) { }) } +func TestCodeChallengeMethodIsSha256(t *testing.T) { + tests := []struct { + name string + method string + wantSha256 bool + wantErr bool + }{ + { + name: "omitted defaults to plain", + method: "", + wantSha256: false, + }, + { + name: "plain", + method: "plain", + wantSha256: false, + }, + { + name: "plain case insensitive", + method: "PLAIN", + wantSha256: false, + }, + { + name: "s256", + method: "S256", + wantSha256: true, + }, + { + name: "s256 case insensitive", + method: "s256", + wantSha256: true, + }, + { + name: "unknown method", + method: "S384", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := codeChallengeMethodIsSha256(tt.method) + if tt.wantErr { + require.Error(t, err) + var invalidRequest *common.OidcInvalidRequestError + require.ErrorAs(t, err, &invalidRequest) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantSha256, got) + }) + } +} + func TestOidcService_updateClientLogoType(t *testing.T) { // Create a test database db := testutils.NewDatabaseForTest(t)