Commit 985537c9 authored by Price, Zach's avatar Price, Zach
Browse files

Update library client to handle token auth

parent 5419789d
Loading
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -51,7 +51,7 @@ def download_file(file, dest, token):
            'unit_scale': True,
            'unit_divisor': 1024,
        }
        with tqdm.wrapattr(open(final_path, mode='wb'), 'write', *tqdm_args) as f:
        with tqdm.wrapattr(open(final_path, mode='wb'), 'write', **tqdm_args) as f:
            for chunk in r.iter_content(chunk_size=config.download_chunk_size):
                f.write(chunk)

+56 −8
Original line number Diff line number Diff line
import errno
import logging
import os
from collections import namedtuple

from fsspec.implementations.http import HTTPFileSystem, HTTPFile
import requests
from aiohttp import BasicAuth
from fsspec.asyn import sync, sync_wrapper
from fsspec.caching import AllBytes
from fsspec.implementations.http import HTTPFile, HTTPFileSystem

from . import config
import logging
from fsspec.asyn import sync_wrapper, sync

logger = logging.getLogger(__name__)


class BearerAuth(namedtuple("BearerAuth", ["token", "encoding"])):
    """HTTP bearer-token authentication helper.

    An immutable ``(token, encoding)`` pair that renders itself as the value
    of an ``Authorization: Bearer <token>`` header.

    ``encoding`` is stored on the instance but is not consulted by
    :meth:`encode`; the token is emitted verbatim.
    NOTE(review): the field presumably exists to mirror the shape of
    aiohttp's ``BasicAuth`` -- confirm before relying on it.
    """

    def __new__(cls, token: str = "", encoding: str = "latin1") -> "BearerAuth":
        """Build a new instance.

        Args:
            token: The bearer token. Must be non-empty.
            encoding: Stored alongside the token; unused by :meth:`encode`.

        Raises:
            ValueError: If ``token`` is empty (or otherwise falsy).
        """
        if not token:
            raise ValueError("Token must be specified")

        return super().__new__(cls, token, encoding)

    def encode(self) -> str:
        """Return the ``Authorization`` header value, ``"Bearer <token>"``."""
        return f"Bearer {self.token}"

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BearerAuth":
        """Create a BearerAuth object from an Authorization HTTP header.

        Args:
            auth_header: Header value of the form ``"<scheme> <token>"``. The
                scheme is matched case-insensitively against ``"bearer"``.
            encoding: Passed through to the constructed instance.

        Returns:
            A new instance of ``cls`` carrying the parsed token.

        Raises:
            ValueError: If the header has no space-separated token part, uses
                a scheme other than ``bearer``, or carries an empty token.
        """
        try:
            # Split on the first space only so the token may contain spaces.
            auth_type, token = auth_header.split(" ", 1)
        except ValueError:
            # Suppress the internal tuple-unpacking error; callers only need
            # to know that the header could not be parsed.
            raise ValueError("Could not parse authorization header.") from None

        if auth_type.lower() != "bearer":
            raise ValueError(f"Unknown authorization method {auth_type}")

        return cls(token, encoding=encoding)


class ARMFs(HTTPFileSystem):

    def __init__(self, block_size=None, cache_type="bytes", cache_options=None, asynchronous=False, loop=None, client_kwargs=None, **storage_options):
    def __init__(self, username='', password='', block_size=None, cache_type="bytes", cache_options=None, asynchronous=False, loop=None, **storage_options):
        self.kwargs = storage_options

        def deny(*args, **kwargs):
@@ -22,6 +55,21 @@ class ARMFs(HTTPFileSystem):
        for func in denied_functions:
            setattr(self, func, deny)

        env_token = os.environ.get('ADL_TOKEN')
        if self.kwargs['username'] and self.kwargs['password']:
            self.auth_data = requests.post(
                f'{str(config.endpoint)}/get_token',
                auth=(self.kwargs['username'], self.kwargs['password'])
            ).json()
        elif env_token:
            self.auth_data['id_token'] = env_token
        else:
            raise ValueError(
                'If both the username and password parameters are not set, you'
                ' must provide a valid auth token via the ADL_TOKEN environment'
                ' variable.'
            )

        super().__init__(
            self,
            block_size=block_size,
@@ -29,14 +77,13 @@ class ARMFs(HTTPFileSystem):
            cache_options=cache_options,
            asynchronous=asynchronous,
            loop=loop,
            client_kwargs=client_kwargs,
            **storage_options
        )

    async def _ls(self, path, detail=True, **kwargs):
        kw = self.kwargs.copy()
        kw.update(kwargs)
        url = f'{config.endpoint}{path}'
        url = f'{config.endpoint}/fs{path}'
        if config.debug:
            print(url)

@@ -65,7 +112,7 @@ class ARMFs(HTTPFileSystem):
    def _open(self, path, mode="rb", block_size=None, cache_type=None, cache_options=None, size=None, **kwargs):
        if mode != "rb":
            raise OSError(errno.EROFS)
        url = f'{config.endpoint}/stream{path}'
        url = f'{config.endpoint}/fs/stream{path}'
        block_size = block_size if block_size is not None else self.block_size
        kw = self.kwargs.copy()
        kw["asynchronous"] = self.asynchronous
@@ -83,6 +130,7 @@ class ARMFs(HTTPFileSystem):
                cache_type=cache_type or self.cache_type,
                cache_options=cache_options or self.cache_options,
                loop=self.loop,
                auth=BearerAuth(self.auth_data['token'])
                **kw,
            )

+2 −1
Original line number Diff line number Diff line
@@ -10,13 +10,14 @@ setup(
        'ormar[postgresql]',
        'fastapi',
        'fastapi-versioning',
        'fsspec',
        'fsspec[fusepy, http]',
        'requests',
        'uvicorn',
        'pandas',
        'pyjwt[crypto]',
        'retry',
        'tqdm',
        'xarray',
    ],
    entry_points={
        'fsspec.specs': [