Commit d6ce4465 authored by Zhang, Chen's avatar Zhang, Chen
Browse files

allow passing rotation center for tilt

parent 6d79fbf9
Loading
Loading
Loading
Loading
+31 −4
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ import param
import multiprocessing
from imars3d.backend.util.functions import clamp_max_workers
import numpy as np
from typing import Tuple
from typing import Tuple, Union, Optional
from functools import partial
from scipy.optimize import minimize_scalar
from scipy.optimize import OptimizeResult
@@ -103,6 +103,7 @@ def calculate_dissimilarity(
    tilt: float,
    image0: np.ndarray,
    image1: np.ndarray,
    center: Optional[Tuple[Union[float, int], Union[float, int]]] = None,
) -> float:
    """Calculate the dissimilarity between two images with given tilt.

@@ -119,6 +120,9 @@ def calculate_dissimilarity(
    image1:
        The second image for comparison, which is often the radiograph taken at
        omega + 180 deg
    center:
        The center of the rotation axis, default is None, which means the center
        of the image. This will be passed to the rotation function from skimage.

    Returns
    -------
@@ -168,6 +172,7 @@ def calculate_dissimilarity(
        resize=True,
        preserve_range=True,
        order=1,  # use default bi-linear interpolation for rotation
        center=center,
    )
    # since 180 is flipped, tilting back -2 deg of the original img180 means tilting +2 deg
    # of the flipped one
@@ -178,6 +183,7 @@ def calculate_dissimilarity(
        resize=True,
        preserve_range=True,
        order=1,  # use default bi-linear interpolation for rotation
        center=center,
    )

    # p-norm
@@ -198,6 +204,7 @@ def calculate_tilt(
    image180: np.ndarray,
    low_bound: float = -5.0,
    high_bound: float = 5.0,
    center: Optional[Tuple[Union[float, int], Union[float, int]]] = None,
) -> OptimizeResult:
    """
    Use optimization to find the in-plane tilt angle.
@@ -214,13 +221,16 @@ def calculate_tilt(
        The lower bound of the tilt angle search space
    high_bound:
        The upper bound of the tilt angle search space
    center:
        The center of the rotation axis, default is None, which means the center
        of the image. This will be passed to the rotation function from skimage.

    Returns
    -------
        The optimization results from scipy.optimize.minimize_scalar
    """
    # make the error function
    err_func = partial(calculate_dissimilarity, image0=image0, image1=image180)
    err_func = partial(calculate_dissimilarity, image0=image0, image1=image180, center=center)
    # use bounded uni-variable optimizer to locate the tilt angle that minimize
    # the dissimilarity of the 180 deg pair
    res = minimize_scalar(
@@ -249,6 +259,9 @@ class tilt_correction(param.ParameterizedFunction):
    cut_off_angle_deg: float
        The angle in degrees to cut off the rotation axis tilt correction, i.e.
        skip applying tilt correction for tilt angles that are too small.
    center: Any
        The center of the rotation axis, default is None, which means the center
        of the image. This will be passed to the rotation function from skimage.
    max_workers:
        Number of cores to use for parallel median filtering, default is 0,
        which means using all available cores.
@@ -275,6 +288,10 @@ class tilt_correction(param.ParameterizedFunction):
        default=2.0,
        doc="The angle in degrees to cut off the rotation axis tilt correction, i.e. skip applying tilt correction for tilt angles that are too small.",
    )
    center = param.Parameter(
        default=None,
        doc="The center of the rotation axis, default is None, which means the center of the image. This will be passed to the rotation function from skimage.",
    )
    # NOTE:
    # The front and backend are sharing the same computing unit, therefore we can
    # set a hard cap on the max_workers.
@@ -329,6 +346,7 @@ class tilt_correction(param.ParameterizedFunction):
                    calculate_tilt,
                    low_bound=params.low_bound,
                    high_bound=params.high_bound,
                    center=params.center,
                ),
                [shm_arrays[il] for il in idx_lowrange],
                [shm_arrays[ih] for ih in idx_highrange],
@@ -349,6 +367,7 @@ class tilt_correction(param.ParameterizedFunction):
            corrected_array = apply_tilt_correction(
                arrays=params.arrays,
                tilt=tilt,
                center=params.center,
                max_workers=self.max_workers,
            )
        return corrected_array
@@ -366,6 +385,8 @@ class apply_tilt_correction(param.ParameterizedFunction):
        The array for tilt correction
    tilt: float
        The rotation axis tilt angle in degrees
    center: Any
        The center of the rotation axis, default is None, which means the center
    max_workers: int
        Number of cores to use for parallel median filtering, default is 0, which means using all available cores.
    tqdm_class: panel.widgets.Tqdm
@@ -379,6 +400,10 @@ class apply_tilt_correction(param.ParameterizedFunction):

    arrays = param.Array(doc="The array for tilt correction", default=None)
    tilt = param.Number(doc="The rotation axis tilt angle in degrees", default=None)
    center = param.Parameter(
        default=None,
        doc="The center of the rotation axis, default is None, which means the center of the image. This will be passed to the rotation function from skimage.",
    )
    # NOTE:
    # The front and backend are sharing the same computing unit, therefore we can
    # set a hard cap on the max_workers.
@@ -406,7 +431,9 @@ class apply_tilt_correction(param.ParameterizedFunction):
        # dimensionality check
        if params.arrays.ndim == 2:
            logger.info(f"2D image detected, applying tilt correction with tilt = {params.tilt:.3f} deg")
            corrected_array = rotate(params.arrays, -params.tilt, resize=False, preserve_range=True)
            corrected_array = rotate(
                params.arrays, -params.tilt, resize=False, preserve_range=True, center=params.center
            )
        elif params.arrays.ndim == 3:
            logger.info(f"3D array detected, applying tilt correction with tilt = {params.tilt:.3f} deg")
            with SharedMemoryManager() as smm:
@@ -420,7 +447,7 @@ class apply_tilt_correction(param.ParameterizedFunction):
                if params.tqdm_class:
                    kwargs["tqdm_class"] = params.tqdm_class
                rst = process_map(
                    partial(rotate, angle=-params.tilt, resize=False, preserve_range=True),
                    partial(rotate, angle=-params.tilt, resize=False, preserve_range=True, center=params.center),
                    [shm_arrays[idx] for idx in range(params.arrays.shape[0])],
                    **kwargs,
                )