Commit 01f178d3 authored by Peterson, Peter's avatar Peterson, Peter
Browse files

Save the omega angles too

parent c1036f94
Loading
Loading
Loading
Loading
+12 −6
Original line number Diff line number Diff line
@@ -491,15 +491,21 @@ def _to_time_str(value: datetime) -> str:
    return value.strftime("%Y%m%d%H%M")


def _save_data(data: np.ndarray, filename: Path) -> None:
def _save_data(filename: Path, data: np.ndarray, omegas: np.ndarray = None) -> None:
    if data is None:
        raise ValueError("Failed to supply data")
    logger.info(f'saving tiffs to "{filename.parent}"')

    # make sure the directory exists
    if not filename.parent.exists():
        filename.parent.mkdir(parents=True)
    # save the stack of tiffs
    dxchange.write_tiff_stack(data, fname=str(filename))

    # save the angles as a numpy object
    if omegas is not None:
        np.save(file=filename.parent / "omegas.npy", arr=omegas)


class save_data(param.ParameterizedFunction):
    """
@@ -518,6 +524,8 @@ class save_data(param.ParameterizedFunction):
        ``param.Foldername`` will warn if the directory does not already exist.
    name: str
        Used to name file of output, defaults to ``save_data``
    omegas: Array
        Optional for writing out the array of omega angles

    Returns
    -------
@@ -527,6 +535,7 @@ class save_data(param.ParameterizedFunction):
    data = param.Array(doc="Data to save", precedence=1)
    outputbase = param.Foldername(default="/tmp/", doc="radiograph directory")
    name = param.String(default="save_data", doc="name for the radiograph")
    omegas = param.Array(doc="Collection of omega angles")

    def __call__(self, **params):
        """Parse inputs and perform multiple dispatch."""
@@ -543,7 +552,7 @@ class save_data(param.ParameterizedFunction):
        save_dir = params.outputbase / f"{params.name}_{_to_time_str(datetime.now())}"

        # save the data as tiffs
        _save_data(data=params.data, filename=save_dir / params.name)
        _save_data(filename=save_dir / params.name, data=params.data, omegas=params.omegas)

        return save_dir

@@ -591,9 +600,6 @@ class save_checkpoint(param.ParameterizedFunction):
        save_dir = params.outputbase / f"{params.name}_chkpt_{_to_time_str(datetime.now())}"

        # save the data as tiffs
        _save_data(data=params.data, filename=save_dir / params.name)
        # save the angles as a numpy object
        if params.omegas is not None:
            np.save(file=save_dir / "omegas.npy", arr=params.omegas)
        _save_data(filename=save_dir / params.name, data=params.data, omegas=params.omegas)

        return save_dir
+8 −5
Original line number Diff line number Diff line
@@ -285,15 +285,18 @@ def create_fake_data():

@pytest.mark.parametrize("name", ["junk", ""])  # gets default name
def test_save_data(name):
    data = create_fake_data()
    omegas = np.asarray([1.0, 2.0, 3.0])
    # this context will remove directory on exit
    with TemporaryDirectory() as tmpdirname:
        assert tmpdirname
        data = create_fake_data()
        tmpdir = Path(tmpdirname)

        # run the code
        numfiles = 3
        if name:
            outputdir = save_data(data=data, outputbase=tmpdir, name=name)
            outputdir = save_data(data=data, outputbase=tmpdir, name=name, omegas=omegas)
            numfiles += 1
        else:
            outputdir = save_data(data=data, outputbase=tmpdir)
        print(outputdir)
@@ -305,10 +308,10 @@ def test_save_data(name):
            prefix = "save_data_"  # special name

        assert outputdir.name.startswith(prefix), str(outputdir.name)
        check_savefiles(outputdir, prefix)
        check_savefiles(outputdir, prefix, has_omega=bool(name), num_files=numfiles)


def xtest_save_data_subdir():
def test_save_data_subdir():
    name = "subdirtest"
    # this context will remove directory on exit
    with TemporaryDirectory() as tmpdirname:
@@ -320,7 +323,7 @@ def xtest_save_data_subdir():
        assert outputdir.name.startswith(f"{name}_"), str(outputdir)

        # check the result
        check_savefiles(tmpdir, "subdirtest_")
        check_savefiles(outputdir, "subdirtest_")


def test_save_checkpoint():