Unverified Commit 0b2e4fb1 authored by Jose Borreguero's avatar Jose Borreguero Committed by GitHub
Browse files

fix integration backend test_run (#257)

* fix test_run fetching of output
* use np.testing.assert_allclose
* change omegas to rot_angles
parent a463ce58
Loading
Loading
Loading
Loading
+13 −13
Original line number Diff line number Diff line
@@ -161,7 +161,7 @@ class load_data(param.ParameterizedFunction):
            logger.warning("No valid signature found, need to specify either files or dir")
            raise ValueError("No valid signature found, need to specify either files or dir")

        # extracting omegas from
        # extracting rotational angles from
        # 1. filename
        # 2. metadata (only possible for Tiff)
        rot_angles = _extract_rotation_angles(ct_files)
@@ -491,7 +491,7 @@ def _to_time_str(value: datetime) -> str:
    return value.strftime("%Y%m%d%H%M")


def _save_data(filename: Path, data: np.ndarray, omegas: np.ndarray = None) -> None:
def _save_data(filename: Path, data: np.ndarray, rot_angles: np.ndarray = None) -> None:
    if data is None:
        raise ValueError("Failed to supply data")
    logger.info(f'saving tiffs to "{filename.parent}"')
@@ -503,8 +503,8 @@ def _save_data(filename: Path, data: np.ndarray, omegas: np.ndarray = None) -> N
    dxchange.write_tiff_stack(data, fname=str(filename))

    # save the angles as a numpy object
    if omegas is not None:
        np.save(file=filename.parent / "omegas.npy", arr=omegas)
    if rot_angles is not None:
        np.save(file=filename.parent / "rot_angles.npy", arr=rot_angles)


class save_data(param.ParameterizedFunction):
@@ -524,8 +524,8 @@ class save_data(param.ParameterizedFunction):
        ``param.Foldername`` will warn if the directory does not already exist.
    name: str
        Used to name file of output, defaults to ``save_data``
    omegas: Array
        Optional for writing out the array of omega angles
    rot_angles: Array
        Optional for writing out the array of rotational (omega) angles

    Returns
    -------
@@ -535,7 +535,7 @@ class save_data(param.ParameterizedFunction):
    data = param.Array(doc="Data to save", precedence=1)
    outputbase = param.Foldername(default="/tmp/", doc="radiograph directory")
    name = param.String(default="save_data", doc="name for the radiograph")
    omegas = param.Array(doc="Collection of omega angles")
    rot_angles = param.Array(doc="Collection of omega angles")

    def __call__(self, **params):
        """Parse inputs and perform multiple dispatch."""
@@ -549,10 +549,10 @@ class save_data(param.ParameterizedFunction):
        if params.data is None:
            raise ValueError("Did not supply data")

        save_dir = params.outputbase / f"{params.name}_{_to_time_str(datetime.now())}"
        save_dir = Path(params.outputbase) / f"{params.name}_{_to_time_str(datetime.now())}"

        # save the data as tiffs
        _save_data(filename=save_dir / params.name, data=params.data, omegas=params.omegas)
        _save_data(filename=save_dir / params.name, data=params.data, rot_angles=params.rot_angles)

        return save_dir

@@ -574,8 +574,8 @@ class save_checkpoint(param.ParameterizedFunction):
        ``param.Foldername`` will warn if the directory does not already exist.
    name: str
        Used to name file of output, defaults to output_{datetime}
    omegas: Array
        Optional for writing out the array of omega angles
    rot_angles: Array
        Optional for writing out the array of rotational (omega) angles

    Returns
    -------
@@ -586,7 +586,7 @@ class save_checkpoint(param.ParameterizedFunction):
    outputbase = param.Foldername(default="/tmp/", doc="directory checkpoint should exist in")

    name = param.String(default="*", doc="name for the checkpoint")
    omegas = param.Array(doc="Collection of omega angles")
    rot_angles = param.Array(doc="Collection of rotational (omega) angles")

    def __call__(self, **params):
        """Parse inputs and perform multiple dispatch."""
@@ -600,6 +600,6 @@ class save_checkpoint(param.ParameterizedFunction):
        save_dir = params.outputbase / f"{params.name}_chkpt_{_to_time_str(datetime.now())}"

        # save the data as tiffs
        _save_data(filename=save_dir / params.name, data=params.data, omegas=params.omegas)
        _save_data(filename=save_dir / params.name, data=params.data, rot_angles=params.rot_angles)

        return save_dir
+7 −4
Original line number Diff line number Diff line
@@ -15,6 +15,8 @@ from copy import deepcopy
import json
import numpy as np
from pathlib import Path
import re
from typing import Callable


@pytest.fixture(scope="module")
@@ -52,17 +54,18 @@ class TestWorkflowEngineAuto:
        assert workflow.config == config

    @pytest.mark.datarepo
    def test_run(self, config, THIS_DATA_DIR, cleanfile):
    def test_run(self, config: dict, THIS_DATA_DIR: Path, cleanfile: Callable, caplog):
        cleanfile(self.outputdir)
        workflow = WorkflowEngineAuto(config)
        expected_slice_300 = np.load(str(THIS_DATA_DIR / "expected_slice_300.npy"))
        workflow.run()
        # extract slice and crop to region of interest
        outdir = Path(self.outputdir)
        outfiles = [outdir / item for item in outdir.glob("test*.tiff")]
        tiff_dir = re.search(r'saving tiffs to "([-/\.\w]+)"', caplog.text).groups()[0]
        assert Path(tiff_dir).exists()
        outfiles = [str(tiff_file) for tiff_file in Path(tiff_dir).glob("save_data_*.tiff")]
        result = load_images(outfiles, desc="test", max_workers=clamp_max_workers(None), tqdm_class=None)
        slice_300 = crop_roi(result[300])
        np.allclose(slice_300, expected_slice_300)
        np.testing.assert_allclose(slice_300, expected_slice_300, atol=1.0e-7)

    def test_no_config(self):
        with pytest.raises(TypeError):
+3 −3
Original line number Diff line number Diff line
@@ -270,7 +270,7 @@ def check_savefiles(direc: Path, prefix: str, num_files: int = 3, has_omega=Fals
    for filepath in filepaths:
        print(filepath)
        assert filepath.is_file()
        if has_omega and filepath.name == "omegas.npy":
        if has_omega and filepath.name == "rot_angles.npy":
            continue
        assert filepath.suffix == ".tiff"
        # the names are zero-padded
@@ -295,7 +295,7 @@ def test_save_data(name):
        # run the code
        numfiles = 3
        if name:
            outputdir = save_data(data=data, outputbase=tmpdir, name=name, omegas=omegas)
            outputdir = save_data(data=data, outputbase=tmpdir, name=name, rot_angles=omegas)
            numfiles += 1
        else:
            outputdir = save_data(data=data, outputbase=tmpdir)
@@ -348,7 +348,7 @@ def test_save_checkpoint():
        omegas = np.asarray([1.0, 2.0, 3.0])
        tmpdir = Path(tmpdirname)

        outputdir = save_checkpoint(data=data, outputbase=tmpdir, name=name, omegas=omegas)
        outputdir = save_checkpoint(data=data, outputbase=tmpdir, name=name, rot_angles=omegas)

        assert outputdir.name.startswith(f"{name}_chkpt_"), str(outputdir)
        # check the tiffs result