Unverified Commit 201094c7 authored by Zhang, Chen's avatar Zhang, Chen Committed by GitHub
Browse files

Merge pull request #266 from ornlneutronimaging/IMG448_add_bm3d

[IMG448] Add thin wrapper around new ring removal method based on BM3D
parents f945b001 d7c2dc2f
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -29,3 +29,4 @@ dependencies:
    - versioningit
    - check-wheel-contents
    - pytest-playwright
    - bm3d-streak-removal
+140 −0
Original line number Diff line number Diff line
@@ -6,6 +6,11 @@ import param
from imars3d.backend.util.functions import clamp_max_workers
import scipy
import numpy as np

try:
    import bm3d_streak_removal as bm3dsr
except ImportError:
    bm3dsr = None
from multiprocessing.managers import SharedMemoryManager
from tqdm.contrib.concurrent import process_map
from functools import partial
@@ -13,6 +18,141 @@ from functools import partial
logger = logging.getLogger(__name__)


class bm3d_ring_removal(param.ParameterizedFunction):
    """
    Remove ring artifact from sinograms using BM3D method.

    This method requires BM3D suite, which can be installed by ``pip install bm3d_streak_removal``.

    ref: `10.1107/S1600577521001910 <http://doi.org/10.1107/S1600577521001910>`_

    Parameters
    ----------
    arrays: np.ndarray
        Input radiograph stack.
    extreme_streak_iterations: int
        Number of iterations for extreme streak attenuation.
    extreme_detect_lambda: float
        Consider streaks which are stronger than lambda * local_std as extreme.
    extreme_detect_size: int
        Half window size for extreme streak detection -- total (2*s + 1).
    extreme_replace_size: int
        Half window size for extreme streak replacement -- total (2*s + 1).
    max_bin_iter_horizontal: int
        The number of total horizontal scales (counting the full scale).
    bin_vertical: int
        The factor of vertical binning, e.g. bin_vertical=32 would perform denoising in 1/32th of the original vertical size.
    filter_strength: float
        Strength of BM4D denoising (>0), where 1 is the standard application, >1 is stronger, and <1 is weaker.
    use_slices: bool
        If True, the sinograms will be split horizontally across each binning iteration into overlapping.
    slice_sizes: list
        A list of horizontal sizes for use of the slicing if use_slices=True. By default, slice size is either 39 pixels or 1/5th of the total width of the current iteration, whichever is larger.
    slice_step_sizes: list
        List of number of pixels between slices obtained with use_slices=True, one for each binning iteration. By default 1/4th of the corresponding slice size.
    denoise_indices: list
        Indices of sinograms to denoise; by default, denoises the full stack provided.

    Returns
    -------
        Radiograph stack with ring artifact removed.

    Notes
    -----
    1. The parallel processing is handled at the bm3d level, and it is an intrinsic
    slow correction algorithm running on CPU.
    2. The underlying BM3D library uses stdout to print progress instead of a progress
    bar.
    """

    arrays = param.Array(doc="Input radiograph stack.", default=None)
    # parameters passed to bm3dsr.extreme_streak_attenuation
    extreme_streak_iterations = param.Integer(default=3, doc="Number of iterations for extreme streak attenuation.")
    extreme_detect_lambda = param.Number(
        default=4.0,
        doc="Consider streaks which are stronger than lambda * local_std as extreme.",
    )
    extreme_detect_size = param.Integer(
        default=9,
        doc="Half window size for extreme streak detection -- total (2*s + 1).",
    )
    extreme_replace_size = param.Integer(
        default=2,
        doc="Half window size for extreme streak replacement -- total (2*s + 1).",
    )
    # parameters passed to bm3dsr.multiscale_streak_removal
    max_bin_iter_horizontal = param.Integer(
        default=0,
        doc="The number of total horizontal scales (counting the full scale).",
        bounds=(0, None),
    )
    bin_vertical = param.Integer(
        default=0,
        doc="The factor of vertical binning, e.g. bin_vertical=32 would perform denoising in 1/32th of the original vertical size.",
        bounds=(0, None),
    )
    filter_strength = param.Number(
        default=1.0,
        doc="Strength of BM4D denoising (>0), where 1 is the standard application, >1 is stronger, and <1 is weaker.",
        bounds=(0, None),
    )
    use_slices = param.Boolean(
        default=True,
        doc="If True, the sinograms will be split horizontally across each binning iteration into overlapping.",
    )
    slice_sizes = param.List(
        default=None,
        doc="A list of horizontal sizes for use of the slicing if use_slices=True. By default, slice size is either 39 pixels or 1/5th of the total width of the current iteration, whichever is larger.",
    )
    slice_step_sizes = param.List(
        default=None,
        doc="List of number of pixels between slices obtained with use_slices=True, one for each binning iteration. By default 1/4th of the corresponding slice size.",
    )
    denoise_indices = param.List(
        default=None,
        doc="Indices of sinograms to denoise; by default, denoises the full stack provided.",
    )
    # note: we are skipping the bm3d_profile_obj parameter as bm3d is not explicitly used in iMars3D.

    def __call__(self, **params):
        """See class level documentation for help."""
        if not bm3dsr:
            logger.warning("To use method, make sure to install bm3d_streak_removal package via pip.")
            raise RuntimeError("BM3D suite not installed, please install with pip install bm3d_streak_removal")
        else:
            logger.info("Executing Filter: Remove Ring Artifact with BM3D")
        _ = self.instance(**params)
        params = param.ParamOverrides(self, params)
        # mangle parameters
        if params.max_bin_iter_horizontal == 0:
            params.max_bin_iter_horizontal = "auto"
        if params.bin_vertical == 0:
            params.bin_vertical = "auto"
        # step 1: extreme streak attenuation
        logger.debug("Perform extreme streak attenuation")
        param.arrays = bm3dsr.extreme_streak_attenuation(
            data=params.arrays,
            extreme_streak_iterations=params.extreme_streak_iterations,
            extreme_detect_lambda=params.extreme_detect_lambda,
            extreme_detect_size=params.extreme_detect_size,
            extreme_replace_size=params.extreme_replace_size,
        )
        # step 2: multiscale streak removal
        logger.debug("Perform multiscale streak removal")
        param.arrays = bm3dsr.multiscale_streak_removal(
            data=params.arrays,
            max_bin_iter_horizontal=params.max_bin_iter_horizontal,
            bin_vertical=params.bin_vertical,
            filter_strength=params.filter_strength,
            use_slices=params.use_slices,
            slice_sizes=params.slice_sizes,
            slice_step_sizes=params.slice_step_sizes,
            denoise_indices=params.denoise_indices,
        )
        logger.info("FINISHED Executing Filter: Remove Ring Artifact")
        return param.arrays


class remove_ring_artifact(param.ParameterizedFunction):
    """
    Remove ring artifact from radiograph stack using Ketcham method.
+36 −2
Original line number Diff line number Diff line
@@ -4,11 +4,17 @@ import pytest
import tomopy
from imars3d.backend.corrections.ring_removal import remove_ring_artifact
from imars3d.backend.corrections.ring_removal import remove_ring_artifact_Ketcham
from imars3d.backend.corrections.ring_removal import bm3d_ring_removal

try:
    import bm3d_streak_removal as bm3dsr
except ImportError:
    bm3dsr = None

def get_synthetic_stack(N_omega: int = 181) -> np.ndarray:

def get_synthetic_stack(N_omega: int = 181, size: int = 200) -> np.ndarray:
    omegas = np.linspace(0, np.pi * 2, N_omega)
    shepp3d = tomopy.misc.phantom.shepp3d(size=200)
    shepp3d = tomopy.misc.phantom.shepp3d(size=size)
    # use emission type radiograph to skip the -log step
    projs = tomopy.sim.project.project(shepp3d, omegas, emission=True)
    return projs
@@ -61,5 +67,33 @@ def test_remove_ring_artifact_Ketcham():
    assert err_correction < err_no_correction


@pytest.mark.skipif(not bm3dsr, reason="bm3d not installed, skipping test.")
def test_bm3d_ring_removal():
    # step_0: prepare synthetic data
    # note: we need a tiny test data as bm3d is very slow
    tomo_ideal = get_synthetic_stack(N_omega=21, size=32)
    # mimic the absorption graph
    tomo_ideal = 1.0 - tomo_ideal / tomo_ideal.max()
    # add white noise
    mean = 0.0
    sigma = 1e-2
    tomo_noisy = tomo_ideal + np.random.normal(mean, sigma, tomo_ideal.shape)
    tomo_noisy = np.nan_to_num(np.log(tomo_noisy), nan=0.0, posinf=0.0, neginf=0.0)
    tomo_ideal = np.nan_to_num(np.log(tomo_ideal), nan=0.0, posinf=0.0, neginf=0.0)
    # the streak in sinogram cannot exceed +-0.05
    tomo_noisy_ringy = np.array(tomo_noisy)
    mean = 0.0
    sigma = 1e-2  # use smaller value to ensure the streak is not overwhelming
    for i in range(tomo_noisy_ringy.shape[1]):
        streak_noise_component = np.random.normal(mean, sigma, tomo_noisy_ringy.shape[2])
        tomo_noisy_ringy[:, i, :] = tomo_noisy_ringy[:, i, :] + streak_noise_component
    # step_1: run correction
    tomo_corrected = bm3d_ring_removal(arrays=tomo_noisy_ringy)
    # step_2: verify
    score_ref = np.median(np.absolute(tomo_noisy_ringy - tomo_ideal))
    score_corr = np.median(np.absolute(tomo_corrected - tomo_ideal))
    assert score_corr < score_ref


if __name__ == "__main__":
    pytest.main([__file__])