Unverified Commit a0fc6a2e authored by Marius van den Beek's avatar Marius van den Beek Committed by GitHub
Browse files

Merge pull request #19730 from mvdbeek/fix_tool_state_correction_refactor_action

[24.2] Allow tool state changes in refactor actions
parents c5d214d5 516f99ff
Loading
Loading
Loading
Loading
+33 −4
Original line number Diff line number Diff line
@@ -5,7 +5,8 @@ import { RouterLink } from "vue-router";
import { useRouter } from "vue-router/composables";

import { canMutateHistory } from "@/api";
import { copyWorkflow } from "@/components/Workflow/workflows.services";
import { copyWorkflow, getWorkflowInfo } from "@/components/Workflow/workflows.services";
import { useWorkflowInstance } from "@/composables/useWorkflowInstance";
import { useHistoryItemsStore } from "@/stores/historyItemsStore";
import { useHistoryStore } from "@/stores/historyStore";
import { useUserStore } from "@/stores/userStore";
@@ -52,13 +53,14 @@ const submissionError = ref("");
const workflowError = ref("");
const workflowName = ref("");
const workflowModel: any = ref(null);
const owner = ref<string>();

const currentHistoryId = computed(() => historyStore.currentHistoryId);
const editorLink = computed(
    () => `/workflows/edit?id=${props.workflowId}${props.version ? `&version=${props.version}` : ""}`
);
const historyStatusKey = computed(() => `${currentHistoryId.value}_${lastUpdateTime.value}`);
const isOwner = computed(() => userStore.matchesCurrentUsername(workflowModel.value.runData.owner));
const isOwner = computed(() => userStore.matchesCurrentUsername(owner.value));
const lastUpdateTime = computed(() => historyItemsStore.lastUpdateTime);
const canRunOnHistory = computed(() => {
    if (!currentHistoryId.value) {
@@ -68,6 +70,16 @@ const canRunOnHistory = computed(() => {
    return (history && canMutateHistory(history)) ?? false;
});

if (props.instance) {
    const { workflow } = useWorkflowInstance(props.workflowId);
    watch(workflow, () => {
        if (workflow.value) {
            workflowName.value = workflow.value?.name;
            owner.value = workflow.value?.owner;
        }
    });
}

function handleInvocations(incomingInvocations: any) {
    invocations.value = incomingInvocations;
    // make sure any new histories are added to historyStore
@@ -116,14 +128,31 @@ async function loadRun() {
        hasStepVersionChanges.value = incomingModel.hasStepVersionChanges;
        workflowName.value = incomingModel.name;
        workflowModel.value = incomingModel;
        owner.value = incomingModel.runData.owner;
        loading.value = false;
    } catch (e) {
        workflowError.value = errorMessageAsString(e);
        const errMessage = errorMessageAsString(e);
        if (errMessage === "Workflow step has upgrade messages") {
            hasUpgradeMessages.value = true;
            if (!props.instance) {
                try {
                    const storedWorkflow = await getWorkflowInfo(props.workflowId);
                    owner.value = storedWorkflow.owner;
                    workflowName.value = storedWorkflow.name;
                } catch {
                    // just show original error
                    workflowError.value = errMessage;
                }
            }
        } else {
            workflowError.value = errMessage;
        }
        loading.value = false;
    }
}

async function onImport() {
    const response = await copyWorkflow(props.workflowId, workflowModel.value.runData.owner, props.version);
    const response = await copyWorkflow(props.workflowId, owner.value, props.version);
    router.push(`/workflows/edit?id=${response.id}`);
}

+1 −1
Original line number Diff line number Diff line
@@ -56,7 +56,7 @@ export async function updateWorkflow(id: string, changes: object): Promise<Workf
    return data;
}

export async function copyWorkflow(id: string, currentOwner: string, version?: string): Promise<Workflow> {
export async function copyWorkflow(id: string, currentOwner?: string, version?: string): Promise<Workflow> {
    let path = `/api/workflows/${id}/download`;
    if (version) {
        path += `?version=${version}`;
+1 −1
Original line number Diff line number Diff line
@@ -1983,7 +1983,7 @@ class WorkflowContentsManager(UsesAnnotations):
            dry_run=refactor_request.dry_run,
        )

        module_injector = WorkflowModuleInjector(trans)
        module_injector = WorkflowModuleInjector(trans, allow_tool_state_corrections=True)
        refactor_executor = WorkflowRefactorExecutor(raw_workflow_description, workflow, module_injector)
        action_executions = refactor_executor.refactor(refactor_request)
        refactored_workflow, errors = self.update_workflow_from_raw_description(
+15 −5
Original line number Diff line number Diff line
@@ -322,8 +322,13 @@ class WorkflowModule:
        the step.
        """
        if inputs := self.get_inputs():
            try:
                return self.state.encode(Bunch(inputs=inputs), self.trans.app, nested=nested)
        else:
            except ValueError:
                log.warning("Tool state invalid for workflow module", exc_info=True)
                # Always preferable to save unmodified and continue, I think ... we're explicit about alterations when retrieving any workflow,
                # and it is what we do if we don't have the tool installed (assuming this is a tool).
                pass
        return self.state.inputs

    def get_export_state(self):
@@ -2613,7 +2618,7 @@ class WorkflowModuleInjector:
        self.trans = trans
        self.allow_tool_state_corrections = allow_tool_state_corrections

    def inject(self, step: WorkflowStep, step_args=None, steps=None, **kwargs):
    def inject(self, step: WorkflowStep, step_args=None, steps=None, allow_tool_state_corrections=False, **kwargs):
        """Pre-condition: `step` is an ORM object coming from the database, if
        supplied `step_args` is the representation of the inputs for that step
        supplied via web form.
@@ -2646,7 +2651,12 @@ class WorkflowModuleInjector:

            subworkflow = step.subworkflow
            assert subworkflow
            populate_module_and_state(self.trans, subworkflow, param_map=unjsonified_subworkflow_param_map)
            populate_module_and_state(
                self.trans,
                subworkflow,
                param_map=unjsonified_subworkflow_param_map,
                allow_tool_state_corrections=allow_tool_state_corrections,
            )

    def inject_all(self, workflow: Workflow, param_map=None, ignore_tool_missing_exception=False, **kwargs):
        param_map = param_map or {}
+31 −12
Original line number Diff line number Diff line
@@ -42,19 +42,26 @@ from .schema import (
    UpgradeSubworkflowAction,
    UpgradeToolAction,
)
from ..modules import InputParameterModule
from ..modules import (
    InputParameterModule,
    WorkflowModuleInjector,
)

log = logging.getLogger(__name__)


class WorkflowRefactorExecutor:
    def __init__(self, raw_workflow_description, workflow, module_injector):
    def __init__(self, raw_workflow_description, workflow, module_injector: WorkflowModuleInjector):
        # we mostly use the ga representation, but there may be cases where the
        # models/modules of existing workflow are more usable.
        self.raw_workflow_description = raw_workflow_description
        self.workflow = workflow
        self.module_injector = module_injector
        self.module_injector.inject_all(workflow, ignore_tool_missing_exception=True)
        self.module_injector.inject_all(
            workflow,
            ignore_tool_missing_exception=True,
            allow_tool_state_corrections=module_injector.allow_tool_state_corrections,
        )

    def refactor(self, refactor_request: RefactorActions):
        action_executions = []
@@ -460,10 +467,22 @@ class WorkflowRefactorExecutor:
    def _inject(self, step, execution):
        # compute runtime state, capture upgrade messages that result
        if not hasattr(step, "module"):
            self.module_injector.inject(step)
            self.module_injector.inject(step, allow_tool_state_corrections=True)
        self.module_injector.compute_runtime_state(step)
        if getattr(step, "upgrade_messages", None):
            for key, value in step.upgrade_messages.items():
                if isinstance(value, dict):
                    for input_name, message in value.items():
                        execution.messages.append(
                            RefactorActionExecutionMessage(
                                message=message,
                                message_type=RefactorActionExecutionMessageTypeEnum.tool_state_adjustment,
                                input_name=input_name,
                                step_label=step.label,
                                order_index=step.order_index,
                            )
                        )
                else:
                    message = RefactorActionExecutionMessage(
                        message=value,
                        message_type=RefactorActionExecutionMessageTypeEnum.tool_state_adjustment,
Loading