Unverified Commit b731dde2 authored by John Davis's avatar John Davis Committed by GitHub
Browse files

Merge pull request #19925 from jdavcs/24.1_19880

[24.1] Handle special charater in raw SQL
parents 953a9723 1bf72997
Loading
Loading
Loading
Loading
+14 −6
Original line number Diff line number Diff line
@@ -225,6 +225,12 @@ from galaxy.util.json import safe_loads
from galaxy.util.sanitize_html import sanitize_html

if TYPE_CHECKING:
    from sqlalchemy.sql.expression import BindParameter

    from galaxy.objectstore import (
        BaseObjectStore,
        QuotaSourceMap,
    )
    from galaxy.schema.invocation import InvocationMessageUnion

log = logging.getLogger(__name__)
@@ -652,9 +658,9 @@ WHERE dataset.id IN (SELECT dataset_id FROM per_hist_hdas)
"""


def calculate_user_disk_usage_statements(user_id, quota_source_map, for_sqlite=False):
def calculate_user_disk_usage_statements(user_id: int, quota_source_map: "QuotaSourceMap", for_sqlite: bool = False):
    """Standalone function so can be reused for postgres directly in pgcleanup.py."""
    statements = []
    statements: List[Tuple[str, Dict[str, Any]]] = []
    default_quota_enabled = quota_source_map.default_quota_enabled
    default_exclude_ids = quota_source_map.default_usage_excluded_ids()
    default_cond = "dataset.object_store_id IS NULL" if default_quota_enabled and default_exclude_ids else ""
@@ -670,7 +676,7 @@ def calculate_user_disk_usage_statements(user_id, quota_source_map, for_sqlite=F
UPDATE galaxy_user SET disk_usage = ({default_usage})
WHERE id = :id
"""
    params = {"id": user_id}
    params: Dict[str, Any] = {"id": user_id}
    if default_exclude_ids:
        params["exclude_object_store_ids"] = default_exclude_ids
    statements.append((default_usage, params))
@@ -1147,13 +1153,13 @@ ON CONFLICT
        usage = sa_session.scalar(sql_calc, params)
        return usage

    def calculate_and_set_disk_usage(self, object_store):
    def calculate_and_set_disk_usage(self, object_store: "BaseObjectStore"):
        """
        Calculates and sets user disk usage.
        """
        self._calculate_or_set_disk_usage(object_store=object_store)

    def _calculate_or_set_disk_usage(self, object_store):
    def _calculate_or_set_disk_usage(self, object_store: "BaseObjectStore"):
        """
        Utility to calculate and return the disk usage.  If dryrun is False,
        the new value is set immediately.
@@ -1161,11 +1167,13 @@ ON CONFLICT
        assert object_store is not None
        quota_source_map = object_store.get_quota_source_map()
        sa_session = object_session(self)
        assert sa_session
        assert sa_session.bind
        for_sqlite = "sqlite" in sa_session.bind.dialect.name
        statements = calculate_user_disk_usage_statements(self.id, quota_source_map, for_sqlite)
        for sql, args in statements:
            statement = text(sql)
            binds = []
            binds: List[BindParameter] = []
            for key, _ in args.items():
                expand_binding = key.endswith("s")
                binds.append(bindparam(key, expanding=expand_binding))
+2 −2
Original line number Diff line number Diff line
@@ -368,7 +368,7 @@ class ObjectStore(metaclass=abc.ABCMeta):
        raise NotImplementedError()

    @abc.abstractmethod
    def get_quota_source_map(self):
    def get_quota_source_map(self) -> "QuotaSourceMap":
        """Return QuotaSourceMap describing mapping of object store IDs to quota sources."""

    @abc.abstractmethod
@@ -523,7 +523,7 @@ class BaseObjectStore(ObjectStore):
            badges.append({"type": type, "message": message})
        return badges

    def get_quota_source_map(self):
    def get_quota_source_map(self) -> "QuotaSourceMap":
        # I'd rather keep this abstract... but register_singleton wants it to be instantiable...
        raise NotImplementedError()

+2 −1
Original line number Diff line number Diff line
@@ -401,7 +401,8 @@ class RequiresDiskUsageRecalculation:
            statements = calculate_user_disk_usage_statements(user_id, quota_source_map)

            for sql, args in statements:
                sql, _ = re.subn(r"\:([\w]+)", r"%(\1)s", sql)
                sql = sql.replace("%", "%%")
                sql = re.sub(r"\:([\w]+)", r"%(\1)s", sql)
                new_args = {}
                for key, val in args.items():
                    if isinstance(val, list):