Unverified Commit 615d1c79 authored by mvdbeek's avatar mvdbeek
Browse files

Make search work on sqlite and postgres

parent 2d98fbf2
Loading
Loading
Loading
Loading
+23 −33
Original line number Diff line number Diff line
@@ -366,6 +366,7 @@ class JobSearch:
        id_encoding_helper: IdEncodingHelper,
    ):
        self.sa_session = sa_session
        self.dialect_name = sa_session.get_bind().dialect.name
        self.hda_manager = hda_manager
        self.dataset_collection_manager = dataset_collection_manager
        self.ldda_manager = ldda_manager
@@ -747,6 +748,12 @@ class JobSearch:
        used_ids.append(labeled_col)
        return stmt

    def agg_expression(self, column):
        if self.dialect_name == "sqlite":
            return func.group_concat(column)
        else:
            return func.array_agg(column, order_by=column)

    def _build_stmt_for_hdca(
        self, stmt, data_conditions, used_ids, k, v, user_id, value_index, require_name_match=True
    ):
@@ -853,13 +860,7 @@ class JobSearch:
        # This CTE aggregates the path signature strings of the reference HDCA into a
        # canonical, sorted array. This array represents the complete "signature" of the collection.
        reference_full_signature_cte = (
            select(
                func.array_agg(
                    sqlalchemy.text(
                        f"{signature_elements_cte.c.path_signature_string.name} ORDER BY {signature_elements_cte.c.path_signature_string.name}"
                    )
                ).label("signature_array")
            )
            select(self.agg_expression(signature_elements_cte.c.path_signature_string).label("signature_array"))
            .select_from(signature_elements_cte)
            .cte(f"reference_full_signature_{k}_{value_index}")
        )
@@ -946,11 +947,9 @@ class JobSearch:
        candidate_full_signatures_cte = (
            select(
                candidate_signature_elements_cte.c.candidate_hdca_id,
                func.array_agg(
                    sqlalchemy.text(
                        f"{candidate_signature_elements_cte.c.path_signature_string.name} ORDER BY {candidate_signature_elements_cte.c.path_signature_string.name}"
                    )
                ).label("full_signature_array"),
                self.agg_expression(candidate_signature_elements_cte.c.path_signature_string).label(
                    "full_signature_array"
                ),
            )
            .select_from(candidate_signature_elements_cte)
            .group_by(candidate_signature_elements_cte.c.candidate_hdca_id)
@@ -1099,16 +1098,12 @@ class JobSearch:
            # used for direct comparison with candidate collections.
            reference_full_signature_cte = (
                select(
                    func.array_agg(
                        reference_dce_signature_elements_cte.c.path_signature_string,
                        order_by=reference_dce_signature_elements_cte.c.path_signature_string,
                    ).label("signature_array"),
                    func.array_agg(
                        reference_dce_signature_elements_cte.c.raw_dataset_id_for_ordering.cast(
                            sqlalchemy.Integer
                        ),  # Cast to Integer here
                        order_by=reference_dce_signature_elements_cte.c.path_signature_string,  # Order by full path to ensure consistency
                    ).label("ordered_dataset_id_array"),
                    self.agg_expression(reference_dce_signature_elements_cte.c.path_signature_string).label(
                        "signature_array"
                    ),
                    self.agg_expression(reference_dce_signature_elements_cte.c.raw_dataset_id_for_ordering).label(
                        "ordered_dataset_id_array"
                    ),
                    func.count(reference_dce_signature_elements_cte.c.path_signature_string).label(
                        "element_count"
                    ),  # Count elements based on path_signature_string
@@ -1227,12 +1222,9 @@ class JobSearch:
                select(
                    candidate_dce_signature_elements_cte.c.candidate_dce_id,
                    # Corrected array_agg syntax: pass column directly, use order_by keyword
                    func.array_agg(
                        candidate_dce_signature_elements_cte.c.dataset_id_for_ordered_array.cast(
                            sqlalchemy.Integer
                        ),  # Cast explicitly
                        order_by=candidate_dce_signature_elements_cte.c.path_signature_string,  # Order by the full path
                    ).label("candidate_ordered_dataset_ids_array"),
                    self.agg_expression(candidate_dce_signature_elements_cte.c.dataset_id_for_ordered_array).label(
                        "candidate_ordered_dataset_ids_array"
                    ),
                    func.count(candidate_dce_signature_elements_cte.c.candidate_dce_id).label(
                        "candidate_element_count"
                    ),
@@ -1269,11 +1261,9 @@ class JobSearch:
            final_candidate_signatures_cte = (
                select(
                    candidate_dce_signature_elements_cte.c.candidate_dce_id,
                    func.array_agg(
                        sqlalchemy.text(
                            f"{candidate_dce_signature_elements_cte.c.path_signature_string} ORDER BY {candidate_dce_signature_elements_cte.c.path_signature_string.name}"
                        )
                    ).label("full_signature_array"),
                    self.agg_expression(candidate_dce_signature_elements_cte.c.path_signature_string).label(
                        "full_signature_array"
                    ),
                )
                .select_from(candidate_dce_signature_elements_cte)
                .where(