Unverified Commit ce0d3a75 authored by Marius van den Beek's avatar Marius van den Beek Committed by GitHub
Browse files

Merge pull request #14330 from nsoranzo/release_22.01_dataset_API_fix_extension_list_filter

[22.01] Fix filtering by an extension list in dataset API
parents 80b726eb 8321b869
Loading
Loading
Loading
Loading
+23 −17
Original line number Diff line number Diff line
@@ -977,15 +977,17 @@ class ModelFilterParser(HasAModelManager):
        Set up, extend, or alter `orm_filter_parsers` and `fn_filter_parsers`.
        """
        # note: these are the default filters for all models
        self.orm_filter_parsers.update({
        self.orm_filter_parsers.update(
            {
                # (prob.) applicable to all models
            'id': {'op': ('in')},
            'encoded_id': {'column': 'id', 'op': ('in'), 'val': self.parse_id_list},
                "id": {"op": ("in")},
                "encoded_id": {"column": "id", "op": ("in"), "val": self.parse_id_list},
                # dates can be directly passed through the orm into a filter (no need to parse into datetime object)
            'extension': {'op': ('eq', 'like', 'in')},
            'create_time': {'op': ('le', 'ge', 'lt', 'gt'), 'val': self.parse_date},
            'update_time': {'op': ('le', 'ge', 'lt', 'gt'), 'val': self.parse_date},
        })
                "extension": {"op": ("eq", "like", "in"), "val": {"in": lambda v: v.split(",")}},
                "create_time": {"op": ("le", "ge", "lt", "gt"), "val": self.parse_date},
                "update_time": {"op": ("le", "ge", "lt", "gt"), "val": self.parse_date},
            }
        )

    def build_filter_params(
        self,
@@ -1075,7 +1077,7 @@ class ModelFilterParser(HasAModelManager):
        """
        Attempt to parse a non-ORM filter function.
        """
        # fn_filter_list is a dict: fn_filter_list[ attr ] = { 'opname1' : opfn1, 'opname2' : opfn2, etc. }
        # fn_filter_parsers is a dict: fn_filter_parsers[attr] = {"opname1": opfn1, "opname2": opfn2, etc. }

        # attr, op is a nested dictionary pointing to the filter fn
        attr_map = self.fn_filter_parsers.get(attr, None)
@@ -1095,13 +1097,13 @@ class ModelFilterParser(HasAModelManager):
        return self.parsed_filter(filter_type="function", filter=lambda i: filter_fn(i, val))

    # ---- ORM filters
    def _parse_orm_filter(self, attr, op, val):
    def _parse_orm_filter(self, attr, op, val) -> Optional[ParsedFilter]:
        """
        Attempt to parse a ORM-based filter.

        Using SQLAlchemy, this would yield a sql.elements.BinaryExpression.
        """
        # orm_filter_list is a dict: orm_filter_list[ attr ] = <list of allowed ops>
        # orm_filter_parsers is a dict: orm_filter_parsers[attr] = <column map>
        column_map = self.orm_filter_parsers.get(attr, None)
        if not column_map:
            # no column mapping (not allowlisted)
@@ -1124,16 +1126,20 @@ class ModelFilterParser(HasAModelManager):
        allowed_ops = column_map['op']
        if op not in allowed_ops:
            return None
        op = self._convert_op_string_to_fn(column, op)
        if not op:
        converted_op = self._convert_op_string_to_fn(column, op)
        if not converted_op:
            return None

        # parse the val from string using the 'val' parser if present (otherwise, leave as string)
        val_parser = column_map.get('val', None)
        val_parser = column_map.get("val")
        # val_parser can be a dictionary indexed by the operations, in case different functions
        # need to be called depending on the operation
        if isinstance(val_parser, dict):
            val_parser = val_parser.get(op)
        if val_parser:
            val = val_parser(val)

        orm_filter = op(val)
        orm_filter = converted_op(val)
        return self.parsed_filter(filter_type="orm", filter=orm_filter)

    #: these are the easier/shorter string equivalents to the python operator fn names that need '__' around them
@@ -1171,7 +1177,7 @@ class ModelFilterParser(HasAModelManager):
    # TODO: These should go somewhere central - we've got ~6 parser modules/sections now
    def parse_id_list(self, id_list_string, sep=','):
        """
        Split `id_list_string` at `sep`.
        Split `id_list_string` at `sep` and decode as ids.
        """
        # TODO: move id decoding out
        id_list = [self.app.security.decode_id(id_) for id_ in id_list_string.split(sep)]
+33 −0
Original line number Diff line number Diff line
@@ -86,6 +86,39 @@ class DatasetsApiTestCase(ApiTestCase):
        result = self._get("datasets", payload).json()
        assert len(result) == 0

    def test_search_by_extension(self):
        self.dataset_populator.new_dataset(self.history_id, wait=True)
        payload = {
            "q": ["extension"],
            "qv": ["txt"],
            "history_id": self.history_id,
        }
        assert len(self._get("datasets", payload).json()) == 1
        payload = {
            "q": ["extension"],
            "qv": ["bam"],
            "history_id": self.history_id,
        }
        assert len(self._get("datasets", payload).json()) == 0
        payload = {
            "q": ["extension-in"],
            "qv": ["bam,txt"],
            "history_id": self.history_id,
        }
        assert len(self._get("datasets", payload).json()) == 1
        payload = {
            "q": ["extension-like"],
            "qv": ["t%t"],
            "history_id": self.history_id,
        }
        assert len(self._get("datasets", payload).json()) == 1
        payload = {
            "q": ["extension-like"],
            "qv": ["b%m"],
            "history_id": self.history_id,
        }
        assert len(self._get("datasets", payload).json()) == 0

    def test_invalid_search(self):
        payload = {'limit': 10, 'offset': 0, 'q': ['history_content_type', 'tag-invalid_op'], 'qv': ['dataset', 'notag']}
        index_response = self._get("datasets", payload)