Commit e6aee8b4 authored by John Chilton's avatar John Chilton
Browse files

Fix assertion model linter.

parent a0775063
Loading
Loading
Loading
Loading
+1466 −12

File changed.

Preview size limit exceeded, changes collapsed.

+55 −4
Original line number Diff line number Diff line
@@ -95,7 +95,7 @@ def check_regex(v: typing.Any):
def check_non_negative_if_set(v: typing.Any):
    if v is not None:
        try:
            assert v >= 0
            assert float(v) >= 0
        except TypeError:
            raise AssertionError(f"Invalid type found {v}")
    return v
@@ -137,6 +137,31 @@ class base_{{assertion.name}}_model(AssertionModel):
{% endif %}


class base_{{assertion.name}}_model_relaxed(AssertionModel):
    '''base model for {{assertion.name}} describing attributes.'''
{% for parameter in assertion.parameters %}
{% if not parameter.is_deprecated %}
    {{ parameter.name }}: {{ parameter.lax_type_str }} = Field(
        {{ parameter.field_default_str }},
        description={{ assertion.name }}_{{ parameter.name }}_description,
    )
{% endif %}
{% endfor %}
{% if assertion.children in ["required", "allowed"] %}
    children: typing.Optional["assertion_list"] = None
    asserts: typing.Optional["assertion_list"] = None

{% if assertion.children == "required" %}
    @model_validator(mode='before')
    @classmethod
    def validate_children(self, data: typing.Any):
        if isinstance(data, dict) and 'children' not in data and 'asserts' not in data:
            raise ValueError("At least one of 'children' or 'asserts' must be specified for this assertion type.")
        return data
{% endif %}
{% endif %}


class {{assertion.name}}_model(base_{{assertion.name}}_model):
    r\"\"\"{{ assertion.docstring }}\"\"\"
    that: Literal["{{assertion.name}}"] = "{{assertion.name}}"
@@ -144,6 +169,10 @@ class {{assertion.name}}_model(base_{{assertion.name}}_model):
class {{assertion.name}}_model_nested(AssertionModel):
    r\"\"\"Nested version of this assertion model.\"\"\"
    {{assertion.name}}: base_{{assertion.name}}_model

class {{assertion.name}}_model_relaxed(base_{{assertion.name}}_model_relaxed):
    r\"\"\"{{ assertion.docstring }}\"\"\"
    that: Literal["{{assertion.name}}"] = "{{assertion.name}}"
{% endfor %}

any_assertion_model_flat = Annotated[typing.Union[
@@ -158,8 +187,17 @@ any_assertion_model_nested = typing.Union[
{% endfor %}
]

any_assertion_model_flat_relaxed = Annotated[typing.Union[
{% for assertion in assertions %}
    {{assertion.name}}_model_relaxed,
{% endfor %}
], Field(discriminator="that")]

assertion_list = RootModel[typing.List[typing.Union[any_assertion_model_flat, any_assertion_model_nested]]]

# used to model what the XML conversion should look like - not meant to be consumed outside of
# of Galaxy internals / linting.
relaxed_assertion_list = RootModel[typing.List[any_assertion_model_flat_relaxed]]

class assertion_dict(AssertionModel):
{% for assertion in assertions %}
@@ -310,7 +348,20 @@ class AssertionParameter:

    @property
    def type_str(self) -> str:
        raw_type_str = as_type_str(self.type)
        raw_type_str = as_type_str(self.type, strict=True)
        validators = self.validators[:]
        if self.xml_type_str == "Bytes":
            validators.append("check_bytes")
            validators.append("check_non_negative_if_int")
        if len(validators) > 0:
            validation_str = ",".join([f"BeforeValidator({v})" for v in validators])
            return f"Annotated[{raw_type_str}, {validation_str}]"

        return raw_type_str

    @property
    def lax_type_str(self) -> str:
        raw_type_str = as_type_str(self.type, strict=False)
        validators = self.validators[:]
        if self.xml_type_str == "Bytes":
            validators.append("check_bytes")
@@ -395,11 +446,11 @@ def as_xml_type(target_type) -> str:
    return "xs:string"


def as_type_str(target_type):
def as_type_str(target_type, strict=True):
    if get_origin(target_type) is Annotated:
        args = get_args(target_type)
        if len(args) > 1:
            if args[1].json_type:
            if args[1].json_type and strict:
                return args[1].json_type

        return as_type_str(args[0])
+21 −0
Original line number Diff line number Diff line
@@ -29,6 +29,8 @@ from galaxy.tool_util.parser.util import (
    boolean_true_and_false_values,
    parse_tool_version_with_defaults,
)
from galaxy.tool_util.parser.xml import __parse_assert_list_from_elem
from galaxy.tool_util.verify.assertion_models import relaxed_assertion_list
from galaxy.tool_util.verify.interactor import (
    InvalidToolTestDict,
    ToolTestDescription,
@@ -587,3 +589,22 @@ def split_if_str(value):
    if split:
        value = value.split(",")
    return value


# convert the sort internal structure used by the tool library {tag: string, attributes: dict, children: []}
# into the YAML structure consumed by the test framework {that: string, **atributes}
def tag_structure_to_that_structure(raw_assert):
    as_json = {"that": raw_assert["tag"], **raw_assert.get("attributes", {})}
    children = raw_assert.get("children")
    if children:
        as_json["children"] = list(map(tag_structure_to_that_structure, children))
    return as_json


def assertion_xml_els_to_models(asserts_raw) -> relaxed_assertion_list:
    asserts_raw = __parse_assert_list_from_elem(asserts_raw)

    to_yaml_assertions = []
    for raw_assert in asserts_raw or []:
        to_yaml_assertions.append(tag_structure_to_that_structure(raw_assert))
    return relaxed_assertion_list.model_validate(to_yaml_assertions)
+11 −2
Original line number Diff line number Diff line
import sys
from string import Template

import lxml.etree as ET
import pytest
from pydantic import ValidationError

from galaxy.tool_util.verify.assertion_models import assertion_list
from galaxy.tool_util.verify.codegen import galaxy_xsd_path
from galaxy.tool_util.verify.parse import assertion_xml_els_to_models
from galaxy.util.commands import shell
from galaxy.util.unittest_utils import skip_unless_executable

@@ -92,6 +94,8 @@ valid_xml_assertions = [
    """<has_n_columns n="30" />""",
    """<has_n_columns n="30" delta="4" />""",
    """<has_n_columns n="30" delta="4" sep=" " comment="###" />""",
    """<has_image_width min="500" />""",
    """<has_image_height min="500" />""",
]

invalid_assertions = [
@@ -207,11 +211,11 @@ if sys.version_info < (3, 8): # noqa: UP036
    pytest.skip(reason="Pydantic assertion models require python3.8 or higher", allow_module_level=True)


def test_valid_models_validate():
def test_valid_json_models_validate():
    assertion_list.model_validate(valid_assertions)


def test_invalid_models_do_not_validate():
def test_invalid_json_models_do_not_validate():
    for invalid_assertion in invalid_assertions:
        with pytest.raises(ValidationError):
            assertion_list.model_validate([invalid_assertion])
@@ -235,3 +239,8 @@ def test_invalid_xsd(tmp_path):
        tool_path.write_text(tool_xml)
        ret = shell(["xmllint", "--nowarning", "--noout", "--schema", galaxy_xsd_path, str(tool_path)])
        assert ret != 0, f"{assertion_xml} validated when error expected"


def test_valid_xml_models_validate_after_json_transform():
    for assertion_xml in valid_xml_assertions:
        assertion_xml_els_to_models([ET.fromstring(assertion_xml)])