Commit 06245bb7 authored by Yakubov, Sergey's avatar Yakubov, Sergey
Browse files

add option to set rse/scheme for download

parent 043c81c3
Loading
Loading
Loading
Loading
Loading
+102 −47
Original line number Diff line number Diff line
import copy
import time

try:
    from ..authnz.util import provider_name_to_backend
except ImportError:
    provider_name_to_backend = None
    pass
from ..objectstore import ConcreteObjectStore

import os
@@ -197,28 +202,40 @@ def parse_config_xml(config_xml):
            _config_xml_error("extra_dir")
        extra_dirs = [{k: e.get(k) for k in attrs} for e in e_xml]

        attrs = ("rse", "scheme")
        e_xml = config_xml.findall("rucio_download_scheme")
        rucio_download_schemes = []
        if e_xml:
            rucio_download_schemes = [{k: e.get(k) for k in attrs} for e in e_xml]



        oidc_provider = config_xml.findtext("oidc_provider", None)

        e_xml = config_xml.findall("rucio")
        if e_xml:
            rucio_preferred_rse_name = e_xml[0].get("preferred_rse_name", None)
            rucio_preferred_rse_protocol = e_xml[0].get("preferred_rse_protocol", None)
            rucio_write_rse_name = e_xml[0].get("write_rse_name", None)
            rucio_write_rse_scheme = e_xml[0].get("write_rse_scheme", None)
            rucio_scope = e_xml[0].get("scope", None)
            rucio_register_only = string_as_bool(e_xml[0].get("register_only", "False"))
        else:
            rucio_preferred_rse_name = None
            rucio_preferred_rse_protocol = None
            rucio_write_rse_name = None
            rucio_write_rse_scheme = None
            rucio_scope = None
            rucio_register_only = False

            oidc_provider = None
        return {
            "cache": {
                "size": cache_size,
                "path": staging_path,
            },
            "extra_dirs": extra_dirs,
            "rucio_preferred_rse_name": rucio_preferred_rse_name,
            "rucio_preferred_rse_protocol": rucio_preferred_rse_protocol,
            "rucio_write_rse_name": rucio_write_rse_name,
            "rucio_write_rse_scheme": rucio_write_rse_scheme,
            "rucio_scope": rucio_scope,
            "rucio_register_only": rucio_register_only,
            "rucio_download_schemes": rucio_download_schemes,
            "oidc_provider": oidc_provider,
        }
    except Exception:
        # Toss it back up after logging, we can't continue loading at this point.
@@ -228,64 +245,85 @@ def parse_config_xml(config_xml):

class RucioBroker():
    def __init__(self, rucio_config):
        self.rse_name = rucio_config["rucio_preferred_rse_name"]
        self.rse_protocol = rucio_config["rucio_preferred_rse_protocol"]
        self.write_rse_name = rucio_config["rucio_write_rse_name"]
        self.write_rse_scheme = rucio_config["rucio_write_rse_scheme"]
        self.scope = rucio_config["rucio_scope"]
        self.register_only = rucio_config["rucio_register_only"]
        self.download_schemes = rucio_config["rucio_download_schemes"]
        rucio.common.utils.PREFERRED_CHECKSUM = "md5"
        # rucio config is in a system rucio.cfg file
        self.rucio_client = Client()
        self.upload_client = UploadClient(_client=self.rucio_client)
        self.download_client = DownloadClient(client=self.rucio_client)
        self.ingest_client = InPlaceIngestClient(_client=self.rucio_client)

    def get_rucio_client(self):
        client = Client()
        return client

    def get_rucio_upload_client(self, auth_token=None):
        client = self.get_rucio_client()
        uc = UploadClient(_client=client)
        uc.auth_token = auth_token
        return uc

    def get_rucio_download_client(self, auth_token=None):
        client = self.get_rucio_client()
        dc = DownloadClient(client=client)
        dc.auth_token = auth_token
        return dc

    def get_rucio_ingest_client(self, auth_token=None):
        client = self.get_rucio_client()
        ic = InPlaceIngestClient(_client=client)
        ic.auth_token = auth_token
        return ic

    def register(self, key, source_path):
        key = os.path.basename(key)
        item = {
            "path": source_path,
            "rse": self.rse_name,
            "rse": self.write_rse_name,
            "did_scope": self.scope,
            "did_name": key,
            "pfn": f"file://localhost/{source_path}",
        }
        items = [item]
        self.ingest_client.ingest(items)
        self.get_rucio_ingest_client().ingest(items)

    def upload(self, key, source_path):
        key = os.path.basename(key)
        item = {
            "path": source_path,
            "rse": self.rse_name,
            "rse": self.write_rse_name,
            "did_scope": self.scope,
            "did_name": key,
            "impl": self.rse_protocol,
            "force_scheme": self.write_rse_scheme,
        }
        items = [item]
        self.upload_client.upload(items)
        self.get_rucio_upload_client().upload(items)

    def download(self, key, dest_path):
    def download(self, key, dest_path, auth_token):
        key = os.path.basename(key)
        base_dir = os.path.dirname(dest_path)
        dids = [{"scope": self.scope, "name": key}]
        try:
            repl = next(self.rucio_client.list_replicas(dids))["rses"].keys()
            if self.rse_name in repl:
            repl = next(self.get_rucio_client().list_replicas(dids))["rses"].keys()
            item = None
            for rse_scheme in self.download_schemes:
                if rse_scheme['rse'] in repl:
                    item = {
                        "did": f"{self.scope}:{key}",
                    "impl": self.rse_protocol,
                    "rse": self.rse_name,
                        "force_scheme": rse_scheme['scheme'],
                        "rse": rse_scheme['rse'],
                        "base_dir": base_dir,
                        "no_subdir": True,
                    }
            else:
                    break
            if item is None:
                item = {
                    "did": f"{self.scope}:{key}",
                    "base_dir": base_dir,
                    "no_subdir": True,
                }

            items = [item]
            download_client = DownloadClient(client=self.rucio_client)
            download_client = self.get_rucio_download_client(auth_token=auth_token)
            download_client.download_dids(items)
        except Exception as e:
            log.exception("Cannot download file:" + str(e))
@@ -296,7 +334,7 @@ class RucioBroker():
        key = os.path.basename(key)
        dids = [{"scope": self.scope, "name": key}]
        try:
            repl = next(self.rucio_client.list_replicas(dids))
            repl = next(self.get_rucio_client().list_replicas(dids))
            return "AVAILABLE" in repl['states'].values()
        except:
            return False
@@ -305,17 +343,18 @@ class RucioBroker():
        key = os.path.basename(key)
        dids = [{"scope": self.scope, "name": key}]
        try:
            repl = next(self.rucio_client.list_replicas(dids))
            repl = next(self.get_rucio_client().list_replicas(dids))
            return repl['bytes']
        except:
            return 0

    def delete(self, key):
        rucio_client = self.get_rucio_client()
        key = os.path.basename(key)
        dids = [{"scope": self.scope, "name": key}]
        rses = next(self.rucio_client.list_replicas(dids))["rses"].keys()
        rses = next(rucio_client.list_replicas(dids))["rses"].keys()
        for rse in rses:
            self.rucio_client.delete_replicas(rse, dids)
            rucio_client.delete_replicas(rse, dids)


class RucioObjectStore(ConcreteObjectStore):
@@ -338,18 +377,19 @@ class RucioObjectStore(ConcreteObjectStore):
    def __init__(self, config, config_dict):
        super().__init__(config, config_dict)
        self.rucio_config = {}
        self.rucio_config["rucio_preferred_rse_name"] = config_dict.get("rucio_preferred_rse_name", None)
        self.rucio_config["rucio_preferred_rse_protocol"] = config_dict.get("rucio_preferred_rse_protocol", None)
        self.rucio_config["rucio_write_rse_name"] = config_dict.get("rucio_write_rse_name", None)
        self.rucio_config["rucio_write_rse_scheme"] = config_dict.get("rucio_write_rse_scheme", None)
        self.rucio_config["rucio_register_only"] = config_dict.get("rucio_register_only", False)
        self.rucio_config["rucio_scope"] = config_dict.get("rucio_scope", None)
        self.rucio_config["rucio_download_schemes"] = config_dict.get("rucio_download_schemes", [])

        if 'RUCIO_PREFERRED_RSE_NAME' in os.environ:
            self.rucio_config["rucio_preferred_rse_name"] = os.environ['RUCIO_PREFERRED_RSE_NAME']
        if 'RUCIO_PREFERRED_RSE_PROTOCOL' in os.environ:
            self.rucio_config["rucio_preferred_rse_protocol"] = os.environ['RUCIO_PREFERRED_RSE_PROTOCOL']
        if 'RUCIO_WRITE_RSE_NAME' in os.environ:
            self.rucio_config["rucio_write_rse_name"] = os.environ['RUCIO_WRITE_RSE_NAME']
        if 'RUCIO_WRITE_RSE_SCHEME' in os.environ:
            self.rucio_config["rucio_write_rse_scheme"] = os.environ['RUCIO_WRITE_RSE_SCHEME']
        if 'RUCIO_REGISTER_ONLY' in os.environ:
            self.rucio_config["rucio_register_only"] = string_as_bool(os.environ['RUCIO_REGISTER_ONLY'])

        self.oidc_provider = config_dict.get("oidc_provider", None)
        self.rucio_broker = RucioBroker(self.rucio_config)
        cache_dict = config_dict["cache"]
        if cache_dict is None:
@@ -417,7 +457,7 @@ class RucioObjectStore(ConcreteObjectStore):
    def _get_cache_path(self, rel_path):
        return os.path.abspath(os.path.join(self.staging_path, rel_path))

    def _pull_into_cache(self, rel_path):
    def _pull_into_cache(self, rel_path, auth_token):
        log.debug("rucio _pull_into_cache: " + rel_path)
        # Ensure the cache directory structure exists (e.g., dataset_#_files/)
        rel_path_dir = os.path.dirname(rel_path)
@@ -425,7 +465,7 @@ class RucioObjectStore(ConcreteObjectStore):
            os.makedirs(self._get_cache_path(rel_path_dir), exist_ok=True)
        # Now pull in the file
        dest = self._get_cache_path(rel_path)
        file_ok = self.rucio_broker.download(rel_path, dest)
        file_ok = self.rucio_broker.download(rel_path, dest, auth_token)
        self._fix_permissions(self._get_cache_path(rel_path_dir))
        return file_ok

@@ -570,9 +610,10 @@ class RucioObjectStore(ConcreteObjectStore):
    def _get_data(self, obj, start=0, count=-1, **kwargs):
        rel_path = self._construct_path(obj, **kwargs)
        log.debug("rucio _get_data: " + rel_path)
        auth_token = self._get_token(**kwargs)
        # Check cache first and get file if not there
        if not self._in_cache(rel_path) or os.path.getsize(self._get_cache_path(rel_path)) == 0:
            self._pull_into_cache(rel_path)
            self._pull_into_cache(rel_path, auth_token)
        # Read the file content from cache
        data_file = open(self._get_cache_path(rel_path))
        data_file.seek(start)
@@ -580,9 +621,23 @@ class RucioObjectStore(ConcreteObjectStore):
        data_file.close()
        return content

    def _get_token(self, **kwargs):
        auth_token = kwargs.get("auth_token", None)
        if auth_token:
            return auth_token
        try:
            trans = kwargs.get("trans", None)
            backend = provider_name_to_backend(self.oidc_provider)
            tokens = trans.user.get_oidc_tokens(backend)
            return tokens["id"]
        except Exception as e:
            log.debug("Failed to get auth token: %s", e)
            return None

    def _update_cache(self, obj, **kwargs):
        base_dir = kwargs.get("base_dir", None)
        dir_only = kwargs.get("dir_only", False)
        auth_token = self._get_token(**kwargs)
        rel_path = self._construct_path(obj, **kwargs)
        log.debug("rucio _update_cache: " + rel_path)

@@ -611,7 +666,7 @@ class RucioObjectStore(ConcreteObjectStore):
            if dir_only:  # Directories do not get pulled into cache
                return cache_path
            else:
                if self._pull_into_cache(rel_path):
                if self._pull_into_cache(rel_path, auth_token):
                    return cache_path
        raise ObjectNotFound(f"objectstore.get_filename, no cache_path: {obj}, kwargs: {kwargs}")

@@ -673,7 +728,7 @@ class RucioObjectStore(ConcreteObjectStore):
        # Update the file on rucio
        self.rucio_broker.upload(rel_path, source_file)

    def _get_store_usage_percent(self):
    def _get_store_usage_percent(self, **kwargs):
        log.debug("rucio _get_store_usage_percent, not implemented yet")
        return 0.0