Unverified Commit 47628efd authored by Marius van den Beek's avatar Marius van den Beek Committed by GitHub
Browse files

Merge pull request #16632 from mvdbeek/conditional_fixes

[23.1] Fixes for conditional subworkflow steps
parents 6bb9bc69 735f9186
Loading
Loading
Loading
Loading
+9 −2
Original line number Diff line number Diff line
@@ -94,12 +94,19 @@ class Tree(BaseTree):
            def get_element(collection):
                return collection[index]  # noqa: B023

            when_value = None
            if self.when_values:
                if len(self.when_values) == 1:
                    when_value = self.when_values[0]
                else:
                    when_value = self.when_values[index]

            if substructure.is_leaf:
                yield dict_map(get_element, collection_dict), self.when_values[index] if self.when_values else None
                yield dict_map(get_element, collection_dict), when_value
            else:
                sub_collections = dict_map(lambda collection: get_element(collection).child_collection, collection_dict)
                for element, _when_value in substructure._walk_collections(sub_collections):
                    yield element, self.when_values[index] if self.when_values else None
                    yield element, when_value

    @property
    def is_leaf(self):
+16 −2
Original line number Diff line number Diff line
@@ -39,6 +39,7 @@ class ModelOperationToolAction(DefaultToolAction):
        execution_cache=None,
        collection_info=None,
        job_callback=None,
        skip=False,
        **kwargs,
    ):
        incoming = incoming or {}
@@ -93,11 +94,15 @@ class ModelOperationToolAction(DefaultToolAction):
            history=history,
            tags=preserved_tags,
            hdca_tags=preserved_hdca_tags,
            skip=skip,
        )
        self._record_inputs(trans, tool, job, incoming, inp_data, inp_dataset_collections)
        self._record_outputs(job, out_data, output_collections)
        if job_callback:
            job_callback(job)
        if skip:
            job.state = job.states.SKIPPED
        else:
            job.state = job.states.OK
        trans.sa_session.add(job)

@@ -108,7 +113,7 @@ class ModelOperationToolAction(DefaultToolAction):
        return job, out_data, history

    def _produce_outputs(
        self, trans: "ProvidesUserContext", tool, out_data, output_collections, incoming, history, tags, hdca_tags
        self, trans: "ProvidesUserContext", tool, out_data, output_collections, incoming, history, tags, hdca_tags, skip
    ):
        tag_handler = trans.tag_handler
        tool.produce_outputs(
@@ -128,4 +133,13 @@ class ModelOperationToolAction(DefaultToolAction):
                    value.visible = False
                    mapped_over_elements[name].hda = value

        # We probably need to mark all outputs as skipped, not just the outputs of whatever the database op tools do ?
        # This is probably not exactly right, but it might also work in most cases
        if skip:
            for output_collection in output_collections.out_collections.values():
                output_collection.mark_as_populated()
            for hdca in output_collections.out_collection_instances.values():
                hdca.visible = False
        # Would we also need to replace the datasets with skipped datasets?

        trans.sa_session.add_all(out_data.values())
+16 −2
Original line number Diff line number Diff line
@@ -512,8 +512,15 @@ class WorkflowModule:
        collections_to_match = self._find_collections_to_match(progress, step, all_inputs)
        # Have implicit collections...
        collection_info = self.trans.app.dataset_collection_manager.match_collections(collections_to_match)
        if collection_info and progress.subworkflow_collection_info:
        if collection_info:
            if progress.subworkflow_collection_info:
                # We've mapped over a subworkflow. Slices of the invocation might be conditional
                # and progress.subworkflow_collection_info.when_values holds the appropriate when_values
                collection_info.when_values = progress.subworkflow_collection_info.when_values
            else:
                # The invocation is not mapped over, but it might still be conditional.
                # Multiplication and linking should be handled by slice_collection()
                collection_info.when_values = progress.when_values
        return collection_info or progress.subworkflow_collection_info

    def _find_collections_to_match(self, progress, step, all_inputs):
@@ -2295,6 +2302,13 @@ class ToolModule(WorkflowModule):
            self._handle_mapped_over_post_job_actions(
                step, step_inputs, step_outputs, progress.effective_replacement_dict()
            )
            if progress.when_values == [False] and not progress.subworkflow_collection_info:
                # Step skipped entirely. We hide the output to avoid confusion.
                # Could be revisited if we have a nice visual way to say these are skipped ?
                for output in step_outputs.values():
                    if isinstance(output, (model.HistoryDatasetAssociation, model.HistoryDatasetCollectionAssociation)):
                        output.visible = False

        if execution_tracker.execution_errors:
            # TODO: formalize into InvocationFailure ?
            message = f"Failed to create {len(execution_tracker.execution_errors)} job(s) for workflow step {step.order_index + 1}: {str(execution_tracker.execution_errors[0])}"
+60 −0
Original line number Diff line number Diff line
@@ -2042,6 +2042,66 @@ should_run:
            assert message["details"] == "Type is: str"
            assert message["workflow_step_id"] == 2

    def test_run_workflow_subworkflow_conditional_with_simple_mapping_step(self):
        with self.dataset_populator.test_history() as history_id:
            summary = self._run_workflow(
                """class: GalaxyWorkflow
inputs:
  should_run:
    type: boolean
  some_collection:
    type: data_collection
steps:
  subworkflow:
    run:
      class: GalaxyWorkflow
      inputs:
        some_collection:
          type: data_collection
        should_run:
          type: boolean
      steps:
        a_tool_step:
          tool_id: cat1
          in:
            input1: some_collection
    in:
      some_collection: some_collection
      should_run: should_run
    outputs:
      inner_out: a_tool_step/out_file1
    when: $(inputs.should_run)
outputs:
  outer_output:
    outputSource: subworkflow/inner_out
""",
                test_data="""
some_collection:
  collection_type: list
  elements:
    - identifier: true
      content: A
    - identifier: false
      content: B
  type: File
should_run:
  value: false
  type: raw
""",
                history_id=history_id,
                wait=True,
                assert_ok=True,
            )
            invocation_details = self.workflow_populator.get_invocation(summary.invocation_id, step_details=True)
            subworkflow_invocation_id = invocation_details["steps"][-1]["subworkflow_invocation_id"]
            self.workflow_populator.wait_for_invocation_and_jobs(
                history_id=history_id, workflow_id="whatever", invocation_id=subworkflow_invocation_id
            )
            invocation_details = self.workflow_populator.get_invocation(subworkflow_invocation_id, step_details=True)
            for step in invocation_details["steps"]:
                if step["workflow_step_label"] == "a_tool_step":
                    assert sum(1 for j in step["jobs"] if j["state"] == "skipped") == 2

    def test_run_workflow_subworkflow_conditional_step(self):
        with self.dataset_populator.test_history() as history_id:
            summary = self._run_workflow(