Commit 31cf8bf3 authored by Zhang, Chen's avatar Zhang, Chen
Browse files

Merge branch 'qa'

parents 174d4316 2fb979b2
Loading
Loading
Loading
Loading
+43 −9
Original line number Diff line number Diff line
@@ -104,14 +104,14 @@ class load_data(param.ParameterizedFunction):

    Notes
    -----
        There are two main signatures to load the data:
        There are three main signatures to load the data:
        1. load_data(ct_files=ctfs, ob_files=obfs, dc_files=dcfs)
        2. load_data(ct_dir=ctdir, ob_dir=obdir, dc_dir=dcdir)
        3. load_data(ct_dir=ctdir, ob_files=obfs, dc_files=dcfs)

        The two signatures are mutually exclusive, and dc_files and dc_dir are optional
        in both cases as some experiments do not have dark current measurements.
        In all signatures dc_files and dc_dir are optional

        The fnmatch selectors are applicable in both signature, which help to down-select
        The fnmatch selectors are applicable in all signature, which help to down-select
        files if needed. Default is set to "*", which selects everything.
        Also, if ob_fnmatch and dc_fnmatch are set to "None" in the second signature call, the
        data loader will attempt to read the metadata embedded in the first ct file to find obs
@@ -157,9 +157,43 @@ class load_data(param.ParameterizedFunction):
        #    use set to simplify call signature checking
        sigs = set([k.split("_")[-1] for k in params.keys() if "fnmatch" not in k])
        ref = {"files", "dir"}
        if sigs.intersection(ref) == {"files", "dir"}:
            logger.error("Files and dir cannot be used at the same time")
            raise ValueError("Mix usage of allowed signature.")

        if ("ct_dir" in params.keys()) and ("ob_files" in params.keys()):
            logger.debug("Load ct by directory, ob and dc (if any) by files")
            ct_dir = params.get("ct_dir")
            if not Path(ct_dir).exists():
                logger.error(f"ct_dir {ct_dir} does not exist.")
                raise ValueError("ct_dir does not exist.")
            else:
                ct_dir = Path(ct_dir)

            # gather the ct_files
            ct_fnmatch = params.get("ct_fnmatch", "*")
            ct_files = ct_dir.glob(ct_fnmatch)
            ct_files = list(map(str, ct_files))
            ct_files.sort()

            ob_files = (params.get("ob_files"),)
            dc_files = (params.get("dc_files", []),)  # it is okay to skip dc

            ob_files = ob_files[0]
            dc_files = dc_files[0]

            ct, ob, dc = _load_by_file_list(
                ct_files=ct_files,
                ob_files=ob_files,
                dc_files=dc_files,  # it is okay to skip dc
                ct_fnmatch=params.get("ct_fnmatch", "*"),  # incase None got leaked here
                ob_fnmatch=params.get("ob_fnmatch", "*"),
                dc_fnmatch=params.get("dc_fnmatch", "*"),
                max_workers=self.max_workers,
                tqdm_class=params.tqdm_class,
            )

        elif ("ct_files" in params.keys()) and ("ob_dir" in params.keys()):
            logger.error("ct_files and ob_dir mixed not allowed!")
            raise ValueError("Mix signatures (ct_files, ob_dir) not allowed!")

        elif sigs.intersection(ref) == {"files"}:
            logger.debug("Load by file list")
            ct, ob, dc = _load_by_file_list(
@@ -227,8 +261,8 @@ def _forgiving_reader(
    """
    try:
        return reader(filename)
    except Exception:
        logger.error(f"Cannot read {filename}, skipping.")
    except Exception as e:
        logger.error(f"While reading {filename}, the following error occurred: {e}")
        return None


+7 −5
Original line number Diff line number Diff line
@@ -97,6 +97,8 @@ def test_load_data(
    # error_0: incorrect input argument types
    with pytest.raises(ValueError):
        load_data(ct_files=1, ob_files=[], dc_files=[])
        load_data(ct_dir=1, ob_files=[])
        load_data(ct_files=[], ob_dir="/tmp")
        load_data(ct_files=[], ob_files=[], dc_files=[], ct_fnmatch=1)
        load_data(ct_files=[], ob_files=[], dc_files=[], ob_fnmatch=1)
        load_data(ct_files=[], ob_files=[], dc_files=[], dc_fnmatch=1)
@@ -106,16 +108,16 @@ def test_load_data(
    # error_1: out of bounds value
    with pytest.raises(ValueError):
        load_data(ct_files=[], ob_files=[], dc_files=[], max_workers=-1)
    # error_2: mix usage of function signature 1 and 2
    with pytest.raises(ValueError):
        load_data(ct_files=[], ob_files=[], dc_files=[], ct_dir="/tmp", ob_dir="/tmp")
    # error_3: no valid signature found
    with pytest.raises(ValueError):
        load_data(ct_fnmatch=1)
    # case_0: load data from file list
    # case_0: load ct from directory, ob and dc from files
    rst = load_data(ct_dir="/tmp", ob_files=["3", "4"], dc_files=["5", "6"])
    np.testing.assert_almost_equal(np.array(rst).flatten(), np.arange(1, 5, dtype=float))
    # case_1: load data from file list
    rst = load_data(ct_files=["1", "2"], ob_files=["3", "4"], dc_files=["5", "6"])
    np.testing.assert_almost_equal(np.array(rst).flatten(), np.arange(1, 5, dtype=float))
    # case_1: load data from given directory
    # case_2: load data from given directory
    rst = load_data(ct_dir="/tmp", ob_dir="/tmp", dc_dir="/tmp")
    np.testing.assert_almost_equal(np.array(rst).flatten(), np.arange(1, 5, dtype=float))