Commit eaa7b7fa authored by Zhang, Chen's avatar Zhang, Chen
Browse files

add unit test for the wrapper

parent 4d9c1da1
Loading
Loading
Loading
Loading
+30 −2
Original line number Diff line number Diff line
@@ -4,11 +4,12 @@ import pytest
import tomopy
from imars3d.backend.corrections.ring_removal import remove_ring_artifact
from imars3d.backend.corrections.ring_removal import remove_ring_artifact_Ketcham
from imars3d.backend.corrections.ring_removal import bm3d_ring_removal


def get_synthetic_stack(N_omega: int = 181) -> np.ndarray:
def get_synthetic_stack(N_omega: int = 181, size: int = 200) -> np.ndarray:
    omegas = np.linspace(0, np.pi * 2, N_omega)
    shepp3d = tomopy.misc.phantom.shepp3d(size=200)
    shepp3d = tomopy.misc.phantom.shepp3d(size=size)
    # use emission type radiograph to skip the -log step
    projs = tomopy.sim.project.project(shepp3d, omegas, emission=True)
    return projs
@@ -61,5 +62,32 @@ def test_remove_ring_artifact_Ketcham():
    assert err_correction < err_no_correction


def test_bm3d_ring_removal():
    # step_0: prepare synthetic data
    # note: we need a tiny test data as bm3d is very slow
    tomo_ideal = get_synthetic_stack(N_omega=21, size=32)
    # mimic the absorption graph
    tomo_ideal = 1.0 - tomo_ideal / tomo_ideal.max()
    # add white noise
    mean = 0.0
    sigma = 1e-2
    tomo_noisy = tomo_ideal + np.random.normal(mean, sigma, tomo_ideal.shape)
    tomo_noisy = np.nan_to_num(np.log(tomo_noisy), nan=0.0, posinf=0.0, neginf=0.0)
    tomo_ideal = np.nan_to_num(np.log(tomo_ideal), nan=0.0, posinf=0.0, neginf=0.0)
    # the streak in sinogram cannot exceed +-0.05
    tomo_noisy_ringy = np.array(tomo_noisy)
    mean = 0.0
    sigma = 1e-2  # use smaller value to ensure the streak is not overwhelming
    for i in range(tomo_noisy_ringy.shape[1]):
        streak_noise_component = np.random.normal(mean, sigma, tomo_noisy_ringy.shape[2])
        tomo_noisy_ringy[:, i, :] = tomo_noisy_ringy[:, i, :] + streak_noise_component
    # step_1: run correction
    tomo_corrected = bm3d_ring_removal(arrays=tomo_noisy_ringy)
    # step_2: verify
    score_ref = np.median(np.absolute(tomo_noisy_ringy - tomo_ideal))
    score_corr = np.median(np.absolute(tomo_corrected - tomo_ideal))
    assert score_corr < score_ref


if __name__ == "__main__":
    pytest.main([__file__])