Commit 62ba0ff4 authored by Cage, Gregory's avatar Cage, Gregory
Browse files

Rollback ornl changes and update interactive tool code with upstream

parent dd83a55a
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -1243,10 +1243,10 @@ class MinimalJobWrapper(HasResourceParameters):
            self.__prepare_upload_paramfile(job)

        tool_evaluator = self._get_tool_evaluator(job)
        if hasattr(self.app, "interactivetool_manager"):
            self.interactivetools = tool_evaluator.populate_interactivetools()
            self.app.interactivetool_manager.create_interactivetool(job, self.tool, self.interactivetools)
            job.interactive_url = self.app.interactivetool_manager.get_job_subdomain(job)
        # if hasattr(self.app, "interactivetool_manager"):
        #     self.interactivetools = tool_evaluator.populate_interactivetools()
        #     self.app.interactivetool_manager.create_interactivetool(job, self.tool, self.interactivetools)
        #     job.interactive_url = self.app.interactivetool_manager.get_job_subdomain(job)

        compute_environment = compute_environment or self.default_compute_environment(job)
        if hasattr(self.app, "interactivetool_manager"):
+24 −33
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ class InteractiveToolSqlite:
            conn = sqlite3.connect(self.sqlite_filename)
            try:
                c = conn.cursor()
                select = f"""SELECT token, host, port, info, protocol
                select = f"""SELECT token, host, port, info, protocol,
                            FROM {DATABASE_TABLE_NAME}
                            WHERE key=? and key_type=?"""
                c.execute(
@@ -76,7 +76,7 @@ class InteractiveToolSqlite:
                                  host text,
                                  port integer,
                                  info text,
                                  protocol text,
                                  protocl text,
                                  PRIMARY KEY (key, key_type)
                                  )"""
                        % (DATABASE_TABLE_NAME)
@@ -146,13 +146,13 @@ class InteractiveToolSqlite:
            entry_point.token,
            entry_point.host,
            entry_point.port,
            json.dumps(
            info=json.dumps(
                {
                    "requires_path_in_url": entry_point.requires_path_in_url,
                    "requires_path_in_header_named": entry_point.requires_path_in_header_named,
                }
            ),
            entry_point.protocol,
            protocol=entry_point.protocol,
        )

    def remove_entry_point(self, entry_point):
@@ -184,21 +184,15 @@ class InteractiveToolManager:
                name=entry["name"],
                label=entry["label"],
                requires_domain=entry["requires_domain"],
                protocol=entry["protocol"],
                requires_path_in_url=entry["requires_path_in_url"],
                requires_path_in_header_named=entry["requires_path_in_header_named"],
                protocol=entry["protocol"],
                short_token=None,
            )
            self.sa_session.add(ep)
        if flush:
            with transaction(self.sa_session):
                self.sa_session.commit()

    def get_job_subdomain(self, job):
        # returns the url for the first entry point
        for ep in job.interactivetool_entry_points:
            return self.calculate_entry_point(self.app.security.encode_id, ep)

    def configure_entry_point(self, job, tool_port=None, host=None, port=None, protocol=None):
        return self.configure_entry_points(
            job, {tool_port: dict(tool_port=tool_port, host=host, port=port, protocol=protocol)}
@@ -216,10 +210,10 @@ class InteractiveToolManager:
            else:
                ep.host = port_dict["host"]
                ep.port = port_dict["port"]
                self.save_entry_point(ep)
                ep.protocol = port_dict["protocol"]
                ep.configured = True
                self.sa_session.add(ep)
                self.save_entry_point(ep)
                configured.append(ep)
        if configured:
            with transaction(self.sa_session):
@@ -309,30 +303,26 @@ class InteractiveToolManager:
                self.sa_session.commit()
        self.propagator.remove_entry_point(entry_point)

    def calculate_entry_point(self, encode_id, entry_point):
        entry_point_encoded_id = encode_id(entry_point.id)
        entry_point_class = entry_point.__class__.__name__.lower()
        entry_point_prefix = self.app.config.interactivetools_prefix
        entry_point_token = entry_point.token
        # if self.app.config.interactivetools_shorten_url:
        #     return f"{entry_point_encoded_id}-{entry_point_token[:10]}.{entry_point_prefix}"
        return f"{entry_point_encoded_id}-{entry_point_token}.{entry_point_class}.{entry_point_prefix}"

    def target_if_active(self, trans, entry_point):
        if entry_point.active and not entry_point.deleted:
            request_host = trans.request.host
            if not self.app.config.interactivetools_upstream_proxy and self.app.config.interactivetools_proxy_host:
                request_host = self.app.config.interactivetools_proxy_host
            protocol = trans.request.host_url.split("//", 1)[0]
            use_it_proxy_host_cfg = (
                not self.app.config.interactivetools_upstream_proxy and self.app.config.interactivetools_proxy_host
            )

            url_parts = urlsplit(trans.request.host_url)
            url_host = self.app.config.interactivetools_proxy_host if use_it_proxy_host_cfg else trans.request.host
            url_path = url_parts.path

            if entry_point.requires_domain:
                rval = f"{protocol}//{self.get_entry_point_subdomain(trans, entry_point)}.{request_host}/"
                url_host = f"{self.get_entry_point_subdomain(trans, entry_point)}.{url_host}"
                if entry_point.entry_url:
                    rval = "{}/{}".format(rval.rstrip("/"), entry_point.entry_url.lstrip("/"))
                    url_path = f"{url_path.rstrip('/')}/{entry_point.entry_url.lstrip('/')}"
            else:
                rval = self.get_entry_point_path(trans, entry_point)
                if not self.app.config.interactivetools_upstream_proxy and self.app.config.interactivetools_proxy_host:
                    rval = f"{protocol}//{request_host}{rval}"
            return rval
                url_path = self.get_entry_point_path(trans, entry_point)
                if not use_it_proxy_host_cfg:
                    return url_path

            return urlunsplit((url_parts.scheme, url_host, url_path, "", ""))

    def _get_entry_point_url_elements(self, trans, entry_point):
        encoder = IdAsLowercaseAlphanumEncodingHelper(trans.security)
@@ -343,7 +333,8 @@ class InteractiveToolManager:
        return ep_encoded_id, ep_class_id, ep_prefix, ep_token

    def get_entry_point_subdomain(self, trans, entry_point):
        return self.calculate_entry_point(trans.security.encode_id, entry_point)
        ep_encoded_id, ep_class_id, ep_prefix, ep_token = self._get_entry_point_url_elements(trans, entry_point)
        return f"{ep_encoded_id}-{ep_token}.{ep_class_id}.{ep_prefix}"

    def get_entry_point_path(self, trans, entry_point):
        url_path = "/"
+3 −1
Original line number Diff line number Diff line
@@ -319,13 +319,15 @@ class XmlToolSource(ToolSource):
            protocol = ep_el.attrib.get("protocol", "http")
            if protocol:
                protocol = protocol.strip()
            if requires_path_in_header_named:
                requires_path_in_header_named = requires_path_in_header_named.strip()
            rtt.append(
                dict(
                    port=port,
                    url=url,
                    name=name,
                    label=label,
                    requires_domain=requires_domain,
                    protocol=protocol,
                    requires_path_in_url=requires_path_in_url,
                    requires_path_in_header_named=requires_path_in_header_named,
                )
+5 −5
Original line number Diff line number Diff line
@@ -196,8 +196,8 @@ class ToolEvaluator:
        if self._history:
            param_dict["__history_id__"] = self.app.security.encode_id(self._history.id)
        param_dict["__galaxy_url__"] = self.compute_environment.galaxy_url()
        if hasattr(self.job, "interactive_url") and isinstance(self.job.interactive_url, str):
            param_dict["__tool_url_prefix__"] = ''.join([self.job.interactive_url, ".", urlparse(self.compute_environment.galaxy_url()).hostname])
        # if hasattr(self.job, "interactive_url") and isinstance(self.job.interactive_url, str):
        #     param_dict["__tool_url_prefix__"] = ''.join([self.job.interactive_url, ".", urlparse(self.compute_environment.galaxy_url()).hostname])

        param_dict.update(self.tool.template_macro_params)
        # All parameters go into the param_dict
@@ -667,11 +667,11 @@ class ToolEvaluator:
            elif inject and inject.startswith("oidc_"):
                environment_variable_template = self.get_oidc_token(inject)
                is_template = False
            elif inject and inject == "entry_point_path" and environment_variable_template:
            elif inject and inject == "entry_point_path_for_label" and environment_variable_template:
                from galaxy.managers.interactivetool import InteractiveToolManager

                entry_point_name = environment_variable_template
                matching_eps = [ep for ep in self.job.interactivetool_entry_points if ep.name == entry_point_name]
                entry_point_label = environment_variable_template
                matching_eps = [ep for ep in self.job.interactivetool_entry_points if ep.label == entry_point_label]
                if matching_eps:
                    entry_point = matching_eps[0]
                    entry_point_path = InteractiveToolManager(self.app).get_entry_point_path(self.app, entry_point)