Unverified Commit 7f569cdb authored by Nicola Soranzo's avatar Nicola Soranzo
Browse files

Type annotation fixes for mypy 1.18.1

parent 060f48e1
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -10779,7 +10779,7 @@ class UserAuthnzToken(Base, UserMixin, RepresentById):
    user_id: Mapped[Optional[int]] = mapped_column(ForeignKey("galaxy_user.id"), index=True)
    uid: Mapped[Optional[str]] = mapped_column(VARCHAR(255))  # type:ignore[assignment]
    provider: Mapped[Optional[str]] = mapped_column(VARCHAR(32))  # type:ignore[assignment]
    extra_data: Mapped[Optional[dict[str, Any]]] = mapped_column(MutableJSONType)
    extra_data: Mapped[Optional[dict[str, Any]]] = mapped_column(MutableJSONType)  # type:ignore[assignment]
    lifetime: Mapped[Optional[int]]
    assoc_type: Mapped[Optional[str]] = mapped_column(VARCHAR(64))
    user: Mapped[Optional["User"]] = relationship(back_populates="social_auth")
+2 −2
Original line number Diff line number Diff line
@@ -85,10 +85,10 @@ class PatchGenericPickle:

        if not issubclass(cls, BaseModel):
            raise TypeError("PatchGenericPickle can only be used with subclasses of pydantic.BaseModel")
        if not issubclass(cls, Generic):  # type: ignore [arg-type]
        if not issubclass(cls, Generic):  # type: ignore[unreachable]  # https://github.com/python/mypy/issues/19377
            raise TypeError("PatchGenericPickle can only be used with Generic models")

        qualname = cls.__qualname__
        qualname = cls.__qualname__  # type: ignore[unreachable]  # https://github.com/python/mypy/issues/19377
        declaring_module = sys.modules[cls.__module__]
        if qualname not in declaring_module.__dict__:
            # This should work in all cases, but we might need to make this check and update more
+13 −12
Original line number Diff line number Diff line
@@ -150,6 +150,7 @@ from galaxy.tools.parameters.basic import (
    DataCollectionToolParameter,
    DataToolParameter,
    HiddenToolParameter,
    ParameterValueError,
    SelectTagParameter,
    SelectToolParameter,
    ToolParameter,
@@ -1497,7 +1498,7 @@ class Tool(UsesDictVisibleKeys, ToolParameterBundle):
        self._is_workflow_compatible = self.check_workflow_compatible(self.tool_source)

    def __parse_legacy_features(self, tool_source: ToolSource):
        self.code_namespace: dict[str, str] = {}
        self.code_namespace: dict[str, Any] = {}
        self.hook_map: dict[str, str] = {}
        self.uihints: dict[str, str] = {}

@@ -2064,7 +2065,7 @@ class Tool(UsesDictVisibleKeys, ToolParameterBundle):
        self, request_context: WorkRequestContext, incoming: ToolRequestT, input_format: InputFormatT = "legacy"
    ) -> tuple[
        list[ToolStateJobInstancePopulatedT],
        list[ToolStateJobInstancePopulatedT],
        list[ParameterValidationErrorsT],
        Optional[int],
        Optional[MatchingCollections],
    ]:
@@ -2150,7 +2151,7 @@ class Tool(UsesDictVisibleKeys, ToolParameterBundle):

    def _handle_validate_input_hook(
        self, request_context, params: ToolStateJobInstancePopulatedT, errors: ParameterValidationErrorsT
    ):
    ) -> None:
        # If the tool provides a `validate_input` hook, call it.
        if validate_input := self.get_hook("validate_input"):
            # hooks are so terrible ... this is specifically for https://github.com/galaxyproject/tools-devteam/blob/main/tool_collections/gops/basecoverage/operation_filter.py
@@ -2200,11 +2201,9 @@ class Tool(UsesDictVisibleKeys, ToolParameterBundle):
        there were no errors).
        """
        request_context = proxy_work_context_for_history(trans, history=history)
        expanded = self.expand_incoming(request_context, incoming=incoming, input_format=input_format)
        all_params: list[ToolStateJobInstancePopulatedT] = expanded[0]
        all_errors: list[ParameterValidationErrorsT] = expanded[1]
        rerun_remap_job_id: Optional[int] = expanded[2]
        collection_info: Optional[MatchingCollections] = expanded[3]
        all_params, all_errors, rerun_remap_job_id, collection_info = self.expand_incoming(
            request_context, incoming=incoming, input_format=input_format
        )

        # If there were errors, we stay on the same page and display them
        self.handle_incoming_errors(all_errors)
@@ -2254,7 +2253,7 @@ class Tool(UsesDictVisibleKeys, ToolParameterBundle):
            param_errors = {}
            for d in all_errors:
                for key, value in d.items():
                    if hasattr(value, "to_dict"):
                    if isinstance(value, ParameterValueError):
                        value_obj = value.to_dict()
                    else:
                        value_obj = {"message": unicodify(value)}
@@ -2454,16 +2453,18 @@ class Tool(UsesDictVisibleKeys, ToolParameterBundle):
        param_dict = job.raw_param_dict()
        return self.params_from_strings(param_dict, ignore_errors=ignore_errors)

    def check_and_update_param_values(self, values, trans, update_values=True, workflow_building_mode=False):
    def check_and_update_param_values(
        self, values, trans, update_values: bool = True, workflow_building_mode: bool = False
    ):
        """
        Check that all parameters have values, and fill in with default
        values where necessary. This could be called after loading values
        from a database in case new parameters have been added.
        """
        messages = {}
        messages: dict[str, Any] = {}
        request_context = proxy_work_context_for_history(trans, workflow_building_mode=workflow_building_mode)

        def validate_inputs(input, value, error, parent, context, prefixed_name, prefixed_label, **kwargs):
        def validate_inputs(input, value, error, parent, context, prefixed_name: str, prefixed_label, **kwargs):
            if not error:
                value, error = check_param(request_context, input, value, context)
            if error:
+45 −48
Original line number Diff line number Diff line
@@ -4,8 +4,10 @@ Classes encapsulating Galaxy tool parameters.

from json import dumps
from typing import (
    Any,
    cast,
    Optional,
    Tuple,
    Union,
)

@@ -280,7 +282,9 @@ def visit_input_values(
            )


def check_param(trans, param, incoming_value, param_values, simple_errors=True):
def check_param(
    trans, param: ToolParameter, incoming_value, param_values, simple_errors: bool = True
) -> Tuple[Any, Union[str, ValueError, None]]:
    """
    Check the value of a single parameter `param`. The value in
    `incoming_value` is converted from its HTML encoding and validated.
@@ -289,11 +293,11 @@ def check_param(trans, param, incoming_value, param_values, simple_errors=True):
    when dealing with grouping scenarios).
    """
    value = incoming_value
    error = None
    error: Union[str, ValueError, None] = None
    try:
        if trans.workflow_building_mode:
            if is_runtime_value(value):
                return [runtime_to_json(value), None]
                return runtime_to_json(value), None
        value = param.from_json(value, trans, param_values)
        param.validate(value, trans)
    except ValueError as e:
@@ -448,13 +452,10 @@ def populate_state(
        for input in inputs.values():
            state[input.name] = input.get_initial_value(request_context, context)
            group_state = state[input.name]
            if input.type == "repeat":
                repeat_input = cast(Repeat, input)
                repeat_name = repeat_input.name
            if isinstance(input, Repeat):
                repeat_name = input.name
                repeat_incoming = incoming.get(repeat_name) or []
                if repeat_incoming and (
                    len(repeat_incoming) > repeat_input.max or len(repeat_incoming) < repeat_input.min
                ):
                if repeat_incoming and (len(repeat_incoming) > input.max or len(repeat_incoming) < input.min):
                    errors[repeat_name] = "The number of repeat elements is outside the range specified by the tool."
                else:
                    del group_state[:]
@@ -464,7 +465,7 @@ def populate_state(
                        repeat_errors: ParameterValidationErrorsT = {}
                        populate_state(
                            request_context,
                            repeat_input.inputs,
                            input.inputs,
                            rep,
                            new_state,
                            repeat_errors,
@@ -474,12 +475,12 @@ def populate_state(
                            input_format=input_format,
                        )
                        if repeat_errors:
                            errors[repeat_input.name] = repeat_errors
                            errors[input.name] = repeat_errors

            elif input.type == "conditional":
                conditional_input = cast(Conditional, input)
                test_param = cast(ToolParameter, conditional_input.test_param)
                test_param_value = incoming.get(conditional_input.name, {}).get(test_param.name)
            elif isinstance(input, Conditional):
                test_param = input.test_param
                assert test_param is not None
                test_param_value = incoming.get(input.name, {}).get(test_param.name)
                value, error = (
                    check_param(request_context, test_param, test_param_value, context, simple_errors=simple_errors)
                    if check
@@ -489,15 +490,13 @@ def populate_state(
                    errors[test_param.name] = error
                else:
                    try:
                        current_case = conditional_input.get_current_case(value)
                        group_state = state[conditional_input.name] = {}
                        current_case = input.get_current_case(value)
                        group_state = state[input.name] = {}
                        cast_errors: ParameterValidationErrorsT = {}
                        incoming_for_conditional = cast(
                            ToolStateJobInstanceT, incoming.get(conditional_input.name) or {}
                        )
                        incoming_for_conditional = cast(ToolStateJobInstanceT, incoming.get(input.name) or {})
                        populate_state(
                            request_context,
                            conditional_input.cases[current_case].inputs,
                            input.cases[current_case].inputs,
                            incoming_for_conditional,
                            group_state,
                            cast_errors,
@@ -507,19 +506,18 @@ def populate_state(
                            input_format=input_format,
                        )
                        if cast_errors:
                            errors[conditional_input.name] = cast_errors
                            errors[input.name] = cast_errors
                        group_state["__current_case__"] = current_case
                    except Exception:
                        errors[test_param.name] = "The selected case is unavailable/invalid."
                group_state[test_param.name] = value

            elif input.type == "section":
                section_input = cast(Section, input)
            elif isinstance(input, Section):
                section_errors: ParameterValidationErrorsT = {}
                incoming_for_state = cast(ToolStateJobInstanceT, incoming.get(section_input.name) or {})
                incoming_for_state = cast(ToolStateJobInstanceT, incoming.get(input.name) or {})
                populate_state(
                    request_context,
                    section_input.inputs,
                    input.inputs,
                    incoming_for_state,
                    group_state,
                    section_errors,
@@ -529,12 +527,13 @@ def populate_state(
                    input_format=input_format,
                )
                if section_errors:
                    errors[section_input.name] = section_errors
                    errors[input.name] = section_errors

            elif input.type == "upload_dataset":
                raise NotImplementedError

            else:
                assert isinstance(input, ToolParameter)
                param_value = _get_incoming_value(incoming, input.name, state.get(input.name))
                value, error = (
                    check_param(request_context, input, param_value, context, simple_errors=simple_errors)
@@ -555,7 +554,7 @@ def _populate_state_legacy(
    inputs: ToolInputsT,
    incoming: ToolStateJobInstanceT,
    state: ToolStateJobInstancePopulatedT,
    errors,
    errors: ParameterValidationErrorsT,
    prefix="",
    context=None,
    check=True,
@@ -569,24 +568,23 @@ def _populate_state_legacy(
        key = prefix + input.name
        group_state = state[input.name]
        group_prefix = f"{key}|"
        if input.type == "repeat":
            repeat_input = cast(Repeat, input)
        if isinstance(input, Repeat):
            rep_index = 0
            del group_state[:]
            while True:
                rep_prefix = f"{key}_{rep_index}"
                rep_min_default = repeat_input.default if repeat_input.default > repeat_input.min else repeat_input.min
                rep_min_default = input.default if input.default > input.min else input.min
                if (
                    not any(incoming_key.startswith(rep_prefix) for incoming_key in incoming.keys())
                    and rep_index >= rep_min_default
                ):
                    break
                if rep_index < repeat_input.max:
                if rep_index < input.max:
                    new_state: ToolStateJobInstancePopulatedT = {"__index__": rep_index}
                    group_state.append(new_state)
                    _populate_state_legacy(
                        request_context,
                        repeat_input.inputs,
                        input.inputs,
                        incoming,
                        new_state,
                        errors,
@@ -596,10 +594,10 @@ def _populate_state_legacy(
                        simple_errors=simple_errors,
                    )
                rep_index += 1
        elif input.type == "conditional":
            conditional_input = cast(Conditional, input)
            test_param = cast(ToolParameter, conditional_input.test_param)
            if conditional_input.value_ref and not conditional_input.value_ref_in_group:
        elif isinstance(input, Conditional):
            test_param = input.test_param
            assert test_param is not None
            if input.value_ref and not input.value_ref_in_group:
                test_param_key = prefix + test_param.name
            else:
                test_param_key = group_prefix + test_param.name
@@ -619,11 +617,11 @@ def _populate_state_legacy(
                errors[test_param_key] = error
            else:
                try:
                    current_case = conditional_input.get_current_case(value)
                    group_state = state[conditional_input.name] = cast(ToolStateJobInstancePopulatedT, {})
                    current_case = input.get_current_case(value)
                    group_state = state[input.name] = cast(ToolStateJobInstancePopulatedT, {})
                    _populate_state_legacy(
                        request_context,
                        conditional_input.cases[current_case].inputs,
                        input.cases[current_case].inputs,
                        incoming,
                        group_state,
                        errors,
@@ -636,11 +634,10 @@ def _populate_state_legacy(
                except Exception:
                    errors[test_param_key] = "The selected case is unavailable/invalid."
            group_state[test_param.name] = value
        elif input.type == "section":
            section_input = cast(Section, input)
        elif isinstance(input, Section):
            _populate_state_legacy(
                request_context,
                section_input.inputs,
                input.inputs,
                incoming,
                group_state,
                errors,
@@ -649,14 +646,13 @@ def _populate_state_legacy(
                check=check,
                simple_errors=simple_errors,
            )
        elif input.type == "upload_dataset":
            dataset_input = cast(UploadDataset, input)
            file_count = dataset_input.get_file_count(request_context, context)
        elif isinstance(input, UploadDataset):
            file_count = input.get_file_count(request_context, context)
            while len(group_state) > file_count:
                del group_state[-1]
            while file_count > len(group_state):
                new_state_upload: ToolStateJobInstancePopulatedT = {"__index__": len(group_state)}
                for upload_item in dataset_input.inputs.values():
                for upload_item in input.inputs.values():
                    new_state_upload[upload_item.name] = upload_item.get_initial_value(request_context, context)
                group_state.append(new_state_upload)
            for rep_index, rep_state in enumerate(group_state):
@@ -664,7 +660,7 @@ def _populate_state_legacy(
                rep_prefix = f"{key}_{rep_index}|"
                _populate_state_legacy(
                    request_context,
                    dataset_input.inputs,
                    input.inputs,
                    incoming,
                    rep_state,
                    errors,
@@ -674,6 +670,7 @@ def _populate_state_legacy(
                    simple_errors=simple_errors,
                )
        else:
            assert isinstance(input, ToolParameter)
            param_value = _get_incoming_value(incoming, key, state.get(input.name))
            value, error = (
                check_param(request_context, input, param_value, context, simple_errors=simple_errors)
+1 −4
Original line number Diff line number Diff line
from typing import (
    Any,
    cast,
    Optional,
    Union,
)
@@ -515,9 +514,7 @@ def from_legacy_install_info(legacy_install_info: LegacyInstallInfoTuple) -> Ins
    extra_info: Union[ExtraRepoInfo, EmptyDict]
    _, repo_metadata_install_info, extra_info = legacy_install_info
    if repo_metadata_install_info:
        metadata_info = RepositoryMetadataInstallInfo.from_legacy_dict(
            cast(RepositoryMetadataInstallInfoDict, repo_metadata_install_info)
        )
        metadata_info = RepositoryMetadataInstallInfo.from_legacy_dict(repo_metadata_install_info)
    else:
        metadata_info = None
    if extra_info: