Unverified Commit a2e598b1 authored by Marius van den Beek's avatar Marius van den Beek Committed by GitHub
Browse files

Merge pull request #19568 from nuwang/overridable_k8s_user_group_ids

[24.2] Make k8s user and group ids overriddable per job
parents 2832585a 3090cef4
Loading
Loading
Loading
Loading
+28 −75
Original line number Diff line number Diff line
@@ -139,10 +139,6 @@ class KubernetesJobRunner(AsynchronousJobRunner):
        self._pykube_api = pykube_client_from_dict(self.runner_params)
        self._galaxy_instance_id = self.__get_galaxy_instance_id()

        self._run_as_user_id = self.__get_run_as_user_id()
        self._run_as_group_id = self.__get_run_as_group_id()
        self._supplemental_group = self.__get_supplemental_group()
        self._fs_group = self.__get_fs_group()
        self._default_pull_policy = self.__get_pull_policy()

        self.setup_base_volumes()
@@ -271,66 +267,29 @@ class KubernetesJobRunner(AsynchronousJobRunner):
        ingress.create()

    def __get_overridable_params(self, job_wrapper, param_key):
        dest_params = self.__get_destination_params(job_wrapper)
        return dest_params.get(param_key, self.runner_params[param_key])
        try:
            return job_wrapper.job_destination.params[param_key]
        except KeyError:
            return self.runner_params[param_key]

    def __get_pull_policy(self):
        return pull_policy(self.runner_params)

    def __get_run_as_user_id(self):
        if self.runner_params.get("k8s_run_as_user_id") or self.runner_params.get("k8s_run_as_user_id") == 0:
            run_as_user = self.runner_params["k8s_run_as_user_id"]
            if run_as_user == "$uid":
    def __get_user_group_param_or_default(self, job_wrapper, param_name):
        substitutable_user_group_id = self.__get_overridable_params(job_wrapper, param_name)
        if substitutable_user_group_id or substitutable_user_group_id == 0:
            if substitutable_user_group_id == "$uid":
                return os.getuid()
            else:
                try:
                    return int(self.runner_params["k8s_run_as_user_id"])
                except Exception:
                    log.warning(
                        'User ID passed for Kubernetes runner needs to be an integer or "$uid", value %s passed is invalid',
                        self.runner_params["k8s_run_as_user_id"],
                    )
                    return None
        return None

    def __get_run_as_group_id(self):
        if self.runner_params.get("k8s_run_as_group_id") or self.runner_params.get("k8s_run_as_group_id") == 0:
            run_as_group = self.runner_params["k8s_run_as_group_id"]
            if run_as_group == "$gid":
            elif substitutable_user_group_id == "$gid":
                return self.app.config.gid
            else:
                try:
                    return int(self.runner_params["k8s_run_as_group_id"])
                except Exception:
                    log.warning(
                        'Group ID passed for Kubernetes runner needs to be an integer or "$gid", value %s passed is invalid',
                        self.runner_params["k8s_run_as_group_id"],
                    )
        return None

    def __get_supplemental_group(self):
        if (
            self.runner_params.get("k8s_supplemental_group_id")
            or self.runner_params.get("k8s_supplemental_group_id") == 0
        ):
            try:
                return int(self.runner_params["k8s_supplemental_group_id"])
                    return int(substitutable_user_group_id)
                except Exception:
                    log.warning(
                    'Supplemental group passed for Kubernetes runner needs to be an integer or "$gid", value %s passed is invalid',
                    self.runner_params["k8s_supplemental_group_id"],
                )
                return None
        return None

    def __get_fs_group(self):
        if self.runner_params.get("k8s_fs_group_id") or self.runner_params.get("k8s_fs_group_id") == 0:
            try:
                return int(self.runner_params["k8s_fs_group_id"])
            except Exception:
                log.warning(
                    'FS group passed for Kubernetes runner needs to be an integer or "$gid", value %s passed is invalid',
                    self.runner_params["k8s_fs_group_id"],
                        'param %s passed to Kubernetes runner needs to be an integer or the strings "$uid" or "$gid". Value %s is invalid',
                        param_name,
                        substitutable_user_group_id,
                    )
                    return None
        return None
@@ -406,7 +365,7 @@ class KubernetesJobRunner(AsynchronousJobRunner):
        }
        # TODO include other relevant elements that people might want to use from
        # TODO http://kubernetes.io/docs/api-reference/v1/definitions/#_v1_podspec
        k8s_spec_template["spec"]["securityContext"] = self.__get_k8s_security_context()
        k8s_spec_template["spec"]["securityContext"] = self.__get_k8s_security_context(ajs.job_wrapper)
        extra_metadata = self.runner_params["k8s_job_metadata"] or "{}"
        if isinstance(extra_metadata, str):
            extra_metadata = yaml.safe_load(extra_metadata)
@@ -554,16 +513,20 @@ class KubernetesJobRunner(AsynchronousJobRunner):
            k8s_spec_template["metadata"]["annotations"].update(new_ann)
        return k8s_spec_template

    def __get_k8s_security_context(self):
    def __get_k8s_security_context(self, job_wrapper):
        security_context = {}
        if self._run_as_user_id or self._run_as_user_id == 0:
            security_context["runAsUser"] = self._run_as_user_id
        if self._run_as_group_id or self._run_as_group_id == 0:
            security_context["runAsGroup"] = self._run_as_group_id
        if self._supplemental_group and self._supplemental_group > 0:
            security_context["supplementalGroups"] = [self._supplemental_group]
        if self._fs_group and self._fs_group > 0:
            security_context["fsGroup"] = self._fs_group
        run_as_user_id = self.__get_user_group_param_or_default(job_wrapper, "k8s_run_as_user_id")
        run_as_group_id = self.__get_user_group_param_or_default(job_wrapper, "k8s_run_as_group_id")
        supplemental_group = self.__get_user_group_param_or_default(job_wrapper, "k8s_supplemental_group_id")
        fs_group = self.__get_user_group_param_or_default(job_wrapper, "k8s_fs_group_id")
        if run_as_user_id or run_as_user_id == 0:
            security_context["runAsUser"] = run_as_user_id
        if run_as_group_id or run_as_group_id == 0:
            security_context["runAsGroup"] = run_as_group_id
        if supplemental_group and supplemental_group > 0:
            security_context["supplementalGroups"] = [supplemental_group]
        if fs_group and fs_group > 0:
            security_context["fsGroup"] = fs_group
        return security_context

    def __get_k8s_restart_policy(self, job_wrapper):
@@ -749,16 +712,6 @@ class KubernetesJobRunner(AsynchronousJobRunner):
    def __get_k8s_job_name(self, prefix, job_wrapper):
        return f"{prefix}-{self.__force_label_conformity(job_wrapper.get_id_tag())}"

    def __get_destination_params(self, job_wrapper):
        """Obtains allowable runner param overrides from the destination"""
        job_destination = job_wrapper.job_destination
        OVERRIDABLE_PARAMS = ["k8s_node_selector", "k8s_affinity", "k8s_extra_job_envs"]
        new_params = {}
        for each_param in OVERRIDABLE_PARAMS:
            if each_param in job_destination.params:
                new_params[each_param] = job_destination.params[each_param]
        return new_params

    def check_watched_item(self, job_state):
        """Checks the state of a job already submitted on k8s. Job state is an AsynchronousJobState"""
        jobs = find_job_object_by_name(self._pykube_api, job_state.job_id, self.runner_params["k8s_namespace"])