Commit 0f12ceba (unverified), authored by Nicola Soranzo, committed by GitHub
Browse files

Merge pull request #19485 from nsoranzo/type_annot

Type annotation improvements
parents 2b4772d2 7b9b4965
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -1249,6 +1249,6 @@ def get_jobs_to_check_at_startup(session: galaxy_scoped_session, track_jobs_in_d
    return session.scalars(stmt).all()


def get_job(session, *where_clauses):
def get_job(session: galaxy_scoped_session, *where_clauses):
    stmt = select(Job).where(*where_clauses).limit(1)
    return session.scalars(stmt).first()
+4 −1
Original line number Diff line number Diff line
import logging
from typing import (
    Any,
    Dict,
    Optional,
    TYPE_CHECKING,
    Union,
@@ -28,6 +30,7 @@ log = logging.getLogger(__name__)

if TYPE_CHECKING:
    from galaxy.managers.base import OrmFilterParsersType
    from galaxy.managers.context import ProvidesUserContext


class DynamicToolManager(ModelManager[model.DynamicTool]):
@@ -47,7 +50,7 @@ class DynamicToolManager(ModelManager[model.DynamicTool]):
        stmt = select(DynamicTool).where(DynamicTool.id == object_id)
        return self.session().scalars(stmt).one_or_none()

    def create_tool(self, trans, tool_payload, allow_load=True):
    def create_tool(self, trans: "ProvidesUserContext", tool_payload: Dict[str, Any], allow_load: bool = True):
        if not getattr(self.app.config, "enable_beta_tool_formats", False):
            raise exceptions.ConfigDoesNotAllowException(
                "Set 'enable_beta_tool_formats' in Galaxy config to create dynamic tools."
+1 −1
Original line number Diff line number Diff line
@@ -1371,7 +1371,7 @@ class DynamicTool(Base, Dictifiable, RepresentById):
    tool_directory: Mapped[Optional[str]] = mapped_column(Unicode(255))
    hidden: Mapped[Optional[bool]] = mapped_column(default=True)
    active: Mapped[Optional[bool]] = mapped_column(default=True)
    value: Mapped[Optional[bytes]] = mapped_column(MutableJSONType)
    value: Mapped[Optional[Dict[str, Any]]] = mapped_column(MutableJSONType)

    dict_collection_visible_keys = ("id", "tool_id", "tool_format", "tool_version", "uuid", "active", "hidden")
    dict_element_visible_keys = ("id", "tool_id", "tool_format", "tool_version", "uuid", "active", "hidden")
+6 −0
Original line number Diff line number Diff line
@@ -64,10 +64,14 @@ except ImportError:

try:
    from cwltool.utils import (
        CWLObjectType,
        JobsType,
        normalizeFilesDirs,
        visit_class,
    )
except ImportError:
    CWLObjectType = object  # type: ignore[assignment, misc]
    JobsType = object  # type: ignore[misc, unused-ignore]
    visit_class = None  # type: ignore[assignment]
    normalizeFilesDirs = None  # type: ignore[assignment]

@@ -115,9 +119,11 @@ def ensure_cwltool_available():

__all__ = (
    "CommentedMap",
    "CWLObjectType",
    "default_loader",
    "ensure_cwltool_available",
    "getdefault",
    "JobsType",
    "load_tool",
    "LoadingContext",
    "main",
+65 −45
Original line number Diff line number Diff line
@@ -13,13 +13,18 @@ from abc import (
    abstractmethod,
)
from typing import (
    Any,
    Dict,
    List,
    Optional,
    overload,
    TYPE_CHECKING,
    Union,
)
from uuid import uuid4
from uuid import (
    UUID,
    uuid4,
)

from typing_extensions import (
    Literal,
@@ -61,6 +66,14 @@ from .schema import (
)
from .util import SECONDARY_FILES_EXTRA_PREFIX

if TYPE_CHECKING:
    from .cwltool_deps import (
        CWLObjectType,
        JobsType,
        workflow,
    )
    from .schema import RawProcessReference

log = logging.getLogger(__name__)

JOB_JSON_FILE = ".cwl_job.json"
@@ -109,18 +122,25 @@ class InputInstanceArrayDict(TypedDict):
class ToolProxy(metaclass=ABCMeta):
    _class: str

    def __init__(self, tool, uuid, raw_process_reference=None, tool_path=None):
    def __init__(
        self,
        tool: process.Process,
        uuid: Union[UUID, str],
        raw_process_reference: Optional["RawProcessReference"] = None,
        tool_path: Optional[str] = None,
    ):
        self._tool = tool
        self.uuid = uuid
        self._tool_path = tool_path
        self._raw_process_reference = raw_process_reference
        # remove input parameter formats from CWL files so that cwltool
        # does not complain they are missing in the input data
        assert isinstance(self._tool.inputs_record_schema, dict)
        for input_field in self._tool.inputs_record_schema["fields"]:
            if "format" in input_field:
                del input_field["format"]

    def job_proxy(self, input_dict, output_dict, job_directory="."):
    def job_proxy(self, input_dict: Dict[str, Any], output_dict, job_directory: str = "."):
        """Build a cwltool.job.Job describing the computation, using an input_json
        that Galaxy generates by mapping the Galaxy description of the inputs into
        a cwltool-compatible variant.
@@ -154,7 +174,7 @@ class ToolProxy(metaclass=ABCMeta):
        """Return InputInstance objects describing mapping to Galaxy inputs."""

    @abstractmethod
    def output_instances(self):
    def output_instances(self) -> List["OutputInstance"]:
        """Return OutputInstance objects describing mapping to Galaxy outputs."""

    @abstractmethod
@@ -185,7 +205,9 @@ class ToolProxy(metaclass=ABCMeta):
        }

    @staticmethod
    def from_persistent_representation(as_object, strict_cwl_validation=True, tool_directory=None) -> "ToolProxy":
    def from_persistent_representation(
        as_object: Dict[str, Any], strict_cwl_validation: bool = True, tool_directory: Optional[str] = None
    ) -> "ToolProxy":
        """Recover an object serialized with to_persistent_representation."""
        if "class" not in as_object:
            raise Exception("Failed to deserialize tool proxy from JSON object - no class found.")
@@ -276,11 +298,7 @@ class CommandLineToolProxy(ToolProxy):
        if outputs_schema["type"] != "record":
            raise Exception("Unhandled CWL tool output structure")

        rval = []
        for output in outputs_schema["fields"]:
            rval.append(_simple_field_to_output(output))

        return rval
        return [_simple_field_to_output(output) for output in outputs_schema["fields"]]

    def docker_identifier(self):
        for hint in self.hints_or_requirements_of_class("DockerRequirement"):
@@ -297,17 +315,18 @@ class ExpressionToolProxy(CommandLineToolProxy):


class JobProxy:
    def __init__(self, tool_proxy, input_dict, output_dict, job_directory):
    _is_command_line_job: bool

    def __init__(self, tool_proxy: ToolProxy, input_dict: Dict[str, Any], output_dict, job_directory: str):
        assert RuntimeContext is not None, "cwltool is not installed, cannot run CWL jobs"
        self._tool_proxy = tool_proxy
        self._input_dict = input_dict
        self._output_dict = output_dict
        self._job_directory = job_directory

        self._final_output = None
        self._final_output: Optional[CWLObjectType] = None
        self._ok = True
        self._cwl_job = None
        self._is_command_line_job = None
        self._cwl_job: Optional[JobsType] = None

        self._normalize_job()

@@ -318,7 +337,6 @@ class JobProxy:
    @property
    def is_command_line_job(self):
        self._ensure_cwl_job_initialized()
        assert self._is_command_line_job is not None
        return self._is_command_line_job

    def _ensure_cwl_job_initialized(self):
@@ -333,9 +351,8 @@ class JobProxy:
                beta_relaxed_fmt_check=beta_relaxed_fmt_check,
            )

            args = [RuntimeContext(job_args)]
            kwargs: Dict[str, str] = {}
            self._cwl_job = next(self._tool_proxy._tool.job(self._input_dict, self._output_callback, *args, **kwargs))
            runtimeContext = RuntimeContext(job_args)
            self._cwl_job = next(self._tool_proxy._tool.job(self._input_dict, self._output_callback, runtimeContext))
            self._is_command_line_job = hasattr(self._cwl_job, "command_line")

    def _normalize_job(self):
@@ -443,7 +460,7 @@ class JobProxy:
        else:
            return {}

    def _output_callback(self, out, process_status):
    def _output_callback(self, out: Optional["CWLObjectType"], process_status: str):
        self._process_status = process_status
        if process_status == "success":
            self._final_output = out
@@ -452,7 +469,7 @@ class JobProxy:

        log.info(f"Output are {out}, status is {process_status}")

    def collect_outputs(self, tool_working_directory, rcode):
    def collect_outputs(self, tool_working_directory: str, rcode: int):
        if not self.is_command_line_job:
            cwl_job = self.cwl_job()
            if RuntimeContext is not None:
@@ -465,8 +482,8 @@ class JobProxy:
        else:
            return self.cwl_job().collect_outputs(tool_working_directory, rcode)

    def save_job(self):
        job_file = JobProxy._job_file(self._job_directory)
    def save_job(self) -> None:
        job_file = self._job_file(self._job_directory)
        job_objects = {
            # "tool_path": os.path.abspath(self._tool_proxy._tool_path),
            "tool_representation": self._tool_proxy.to_persistent_representation(),
@@ -529,7 +546,7 @@ class JobProxy:


class WorkflowProxy:
    def __init__(self, workflow, workflow_path=None):
    def __init__(self, workflow: "workflow.Workflow", workflow_path: Optional[str] = None):
        self._workflow = workflow
        self._workflow_path = workflow_path
        self._step_proxies: Optional[List[Union[SubworkflowStepProxy, ToolStepProxy]]] = None
@@ -719,7 +736,11 @@ class WorkflowProxy:


def tool_proxy(
    tool_path=None, tool_object=None, strict_cwl_validation=True, tool_directory=None, uuid=None
    tool_path: Optional[str] = None,
    tool_object=None,
    strict_cwl_validation: bool = True,
    tool_directory: Optional[str] = None,
    uuid: Optional[Union[UUID, str]] = None,
) -> ToolProxy:
    """Provide a proxy object to cwltool data structures to just
    grab relevant data.
@@ -735,7 +756,7 @@ def tool_proxy(


def tool_proxy_from_persistent_representation(
    persisted_tool, strict_cwl_validation=True, tool_directory=None
    persisted_tool: Dict[str, Any], strict_cwl_validation: bool = True, tool_directory: Optional[str] = None
) -> ToolProxy:
    """Load a ToolProxy from a previously persisted representation."""
    ensure_cwltool_available()
@@ -744,12 +765,12 @@ def tool_proxy_from_persistent_representation(
    )


def workflow_proxy(workflow_path, strict_cwl_validation=True) -> WorkflowProxy:
def workflow_proxy(workflow_path: str, strict_cwl_validation: bool = True) -> WorkflowProxy:
    ensure_cwltool_available()
    return _to_cwl_workflow_object(workflow_path, strict_cwl_validation=strict_cwl_validation)


def load_job_proxy(job_directory, strict_cwl_validation=True) -> JobProxy:
def load_job_proxy(job_directory: str, strict_cwl_validation: bool = True) -> JobProxy:
    ensure_cwltool_available()
    job_objects_path = os.path.join(job_directory, JOB_JSON_FILE)
    job_objects = json.load(open(job_objects_path))
@@ -763,19 +784,16 @@ def load_job_proxy(job_directory, strict_cwl_validation=True) -> JobProxy:


def _to_cwl_tool_object(
    tool_path=None,
    tool_path: Optional[str] = None,
    tool_object=None,
    cwl_tool_object=None,
    raw_process_reference=None,
    strict_cwl_validation=False,
    tool_directory=None,
    uuid=None,
    strict_cwl_validation: bool = False,
    tool_directory: Optional[str] = None,
    uuid: Optional[Union[UUID, str]] = None,
) -> ToolProxy:
    if uuid is None:
        uuid = str(uuid4())
    schema_loader = _schema_loader(strict_cwl_validation)
    if raw_process_reference is None and tool_path is not None:
        assert cwl_tool_object is None
    if tool_path is not None:
        assert tool_object is None

        raw_process_reference = schema_loader.raw_process_reference(tool_path)
@@ -783,9 +801,6 @@ def _to_cwl_tool_object(
            raw_process_reference=raw_process_reference,
        )
    elif tool_object is not None:
        assert raw_process_reference is None
        assert cwl_tool_object is None

        # Allow loading tools from YAML...
        as_str = json.dumps(tool_object)
        tool_object = yaml_no_ts().load(as_str)
@@ -799,7 +814,7 @@ def _to_cwl_tool_object(
            raw_process_reference=raw_process_reference,
        )
    else:
        cwl_tool = cwl_tool_object
        raise ValueError("Either tool_path or tool_object should be defined")

    if isinstance(cwl_tool, int):
        raise Exception("Failed to load tool.")
@@ -812,7 +827,12 @@ def _to_cwl_tool_object(
    return _cwl_tool_object_to_proxy(cwl_tool, uuid, raw_process_reference=raw_process_reference, tool_path=tool_path)


def _cwl_tool_object_to_proxy(cwl_tool, uuid, raw_process_reference=None, tool_path=None) -> ToolProxy:
def _cwl_tool_object_to_proxy(
    cwl_tool: process.Process,
    uuid: Union[UUID, str],
    raw_process_reference: Optional["RawProcessReference"] = None,
    tool_path: Optional[str] = None,
) -> ToolProxy:
    raw_tool = cwl_tool.tool
    if "class" not in raw_tool:
        raise Exception("File does not declare a class, not a valid Draft 3+ CWL tool.")
@@ -831,14 +851,14 @@ def _cwl_tool_object_to_proxy(cwl_tool, uuid, raw_process_reference=None, tool_p
    return proxy_class(cwl_tool, uuid, raw_process_reference, tool_path)


def _to_cwl_workflow_object(workflow_path, strict_cwl_validation=None) -> WorkflowProxy:
def _to_cwl_workflow_object(workflow_path: str, strict_cwl_validation: bool = True) -> WorkflowProxy:
    cwl_workflow = _schema_loader(strict_cwl_validation).tool(path=workflow_path)
    raw_workflow = cwl_workflow.tool
    check_requirements(raw_workflow, tool=False)
    return WorkflowProxy(cwl_workflow, workflow_path)


def _schema_loader(strict_cwl_validation):
def _schema_loader(strict_cwl_validation: bool):
    return schema_loader if strict_cwl_validation else non_strict_non_validating_schema_loader


@@ -912,7 +932,7 @@ def split_step_references(step_references, workflow_id=None, multiple=True):
        return split_references[0]


def build_step_proxy(workflow_proxy: WorkflowProxy, step, index):
def build_step_proxy(workflow_proxy: WorkflowProxy, step: "workflow.WorkflowStep", index: int):
    step_type = step.embedded_tool.tool["class"]
    if step_type == "Workflow":
        return SubworkflowStepProxy(workflow_proxy, step, index)
@@ -969,7 +989,7 @@ class InputProxy:


class BaseStepProxy:
    def __init__(self, workflow_proxy: WorkflowProxy, step, index):
    def __init__(self, workflow_proxy: WorkflowProxy, step: "workflow.WorkflowStep", index: int):
        self._workflow_proxy = workflow_proxy
        self._step = step
        self._index = index
@@ -1271,7 +1291,7 @@ OUTPUT_TYPE = Bunch(

# TODO: Different subclasses - this is representing different types of things.
class OutputInstance:
    def __init__(self, name, output_data_type, output_type, path=None, fields=None):
    def __init__(self, name: str, output_data_type, output_type, path=None, fields=None):
        self.name = name
        self.output_data_type = output_data_type
        self.output_type = output_type
Loading