diff --git a/docs/source/conf.py b/docs/source/conf.py index 62294ce6ed4c24349c7cecc643be3fc6cbeba4a9..e8bf4b83a568abbe700f9c1b084fd15e7805114d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -62,7 +62,11 @@ mathjax3_config = { "tex": { "macros": { "R": r"\mathbb{R}", + "diverg": r"\operatorname{div}", + "curl": r"\operatorname{curl}", + "argmin": r"\operatorname{argmin}", "Angstrom": r"\text{Å}", + "atan": r"\operatorname{tan}^{-1}", } }, } diff --git a/docs/source/modules/data/nion.rst b/docs/source/modules/data/nion.rst new file mode 100644 index 0000000000000000000000000000000000000000..dc26840da3228a0e2390c037b0fc699e0c37c45e --- /dev/null +++ b/docs/source/modules/data/nion.rst @@ -0,0 +1,8 @@ +data.nion +========= + +.. automodule:: ptychopath.data.nion + :members: + :special-members: + :exclude-members: __dict__,__weakref__,prepare_data,setup,test_dataloader,train_dataloader,val_dataloader + :show-inheritance: diff --git a/docs/source/modules/microscope.rst b/docs/source/modules/microscope.rst index e26e1ffb8d11116fec07fff1441e3564dd1fdd4c..f0e41b7956b575d8627e9e027403c5129c53d4a7 100644 --- a/docs/source/modules/microscope.rst +++ b/docs/source/modules/microscope.rst @@ -10,6 +10,5 @@ microscope .. toctree:: :glob: :maxdepth: 2 - :caption: Python API Reference microscope/* diff --git a/docs/source/modules/microscope/calibration.rst b/docs/source/modules/microscope/calibration.rst new file mode 100644 index 0000000000000000000000000000000000000000..898eb66e7303712fd8f49493b2306124f40158b8 --- /dev/null +++ b/docs/source/modules/microscope/calibration.rst @@ -0,0 +1,14 @@ +microscope.calibration +====================== + +.. automodule:: ptychopath.microscope.calibration + :members: + :special-members: + :exclude-members: __dict__,__weakref__ + :show-inheritance: + +.. toctree:: + :glob: + :maxdepth: 2 + + calibration/* diff --git a/docs/source/modules/microscope/calibration/aperture.rst b/docs/source/modules/microscope/calibration/aperture.rst new file mode 100644 index 0000000000000000000000000000000000000000..b3302b20755fe3cc9fbc000846ab205db965be11 --- /dev/null +++ b/docs/source/modules/microscope/calibration/aperture.rst @@ -0,0 +1,8 @@ +microscope.calibration.aperture +=============================== + +.. automodule:: ptychopath.microscope.calibration.aperture + :members: + :special-members: + :exclude-members: __dict__,__weakref__ + :show-inheritance: diff --git a/docs/source/modules/microscope/calibration/dpc.rst b/docs/source/modules/microscope/calibration/dpc.rst new file mode 100644 index 0000000000000000000000000000000000000000..0f4629ee2d5e85b78e8da1f43f23d6c8aa523532 --- /dev/null +++ b/docs/source/modules/microscope/calibration/dpc.rst @@ -0,0 +1,8 @@ +microscope.calibration.dpc +========================== + +.. automodule:: ptychopath.microscope.calibration.dpc + :members: + :special-members: + :exclude-members: __dict__,__weakref__ + :show-inheritance: diff --git a/docs/source/modules/regrid.rst b/docs/source/modules/regrid.rst new file mode 100644 index 0000000000000000000000000000000000000000..41e134a01f607e7d313bc62695bd9a036878ed65 --- /dev/null +++ b/docs/source/modules/regrid.rst @@ -0,0 +1,8 @@ +regrid +====== + +.. automodule:: ptychopath.regrid + :members: + :special-members: + :exclude-members: __dict__,__weakref__,__init__ + :show-inheritance: diff --git a/docs/source/refs.bib b/docs/source/refs.bib index d8c1d78ef4ad7bbdfc9e7ab93d53dca29e9f1cf5..4ac932fad5200f7d228f6c198432a9518e1b718b 100644 --- a/docs/source/refs.bib +++ b/docs/source/refs.bib @@ -22,3 +22,19 @@ year={2015}, publisher={Elsevier} } + +@article{ishizuka17, + author = {Ishizuka, Akimitsu and Oka, Masaaki and Seki, Takehito and Shibata, Naoya and Ishizuka, Kazuo}, + title = "{Boundary-artifact-free determination of potential distribution from differential phase contrast signals}", + journal = {Microscopy}, + volume = {66}, + number = {6}, + pages = {397-405}, + year = {2017}, + month = {08}, + abstract = "{The differential phase contrast (DPC) imaging in STEM was mainly used for a study of magnetic material in a medium resolution. An ideal DPC signals give the center of mass of the diffraction pattern, which is proportional to an electric field. Recently, the possibility of the DPC imaging at atomic resolution was demonstrated. Thus, the DPC imaging opens up the possibility to observe the object phase that is proportional to the electrostatic potential.In this report we investigate the numerical procedures to obtain the object phase from the two perpendicular DPC signals. Specifically, we demonstrate that the discrete cosine transform (DCT) is the method to solve the Poisson equation, since we can use the Neumann boundary condition directly specified by the DPC signals. Furthermore, based on the fast Fourier transform (FFT) of an extended DPC signal we introduce the scheme that gives an equivalent result that is obtained with the DCT. The results obtained with the DCT and extended FFT method are superior to the results obtained with commonly used FFT. In addition, we develop real-time integration schemes that update the result with the progress of the scan. Our real-time integration gives the reasonable result, and can be used in a view mode. We demonstrate that our numerical procedures work excellently with the experimental DPC signals obtained from SrTiO3 single crystal.}", + issn = {2050-5698}, + doi = {10.1093/jmicro/dfx032}, + url = {https://doi.org/10.1093/jmicro/dfx032}, + eprint = {https://academic.oup.com/jmicro/article-pdf/66/6/397/22136416/dfx032.pdf}, +} diff --git a/docs/source/topics/bias.rst b/docs/source/topics/bias.rst new file mode 100644 index 0000000000000000000000000000000000000000..c1c4b01441f3e2f72021eba08690f3b942f7d69d --- /dev/null +++ b/docs/source/topics/bias.rst @@ -0,0 +1,104 @@ +====================== +Handling external bias +====================== + +In our experiments, we use a biasing sample holder to apply an external +electrical bias. This influences our computation of the :term:`multislice` +method since it means we are adding a term to $U(x)$ that is linear in $x$, of +the form $u\cdot x$ for some vector $u\in\R^3$. + +Here we will follow :cite:t:`narangifard2013` chapter 3, but instead of +propagating in free space, we will assume a constant bias. All equation numbers +refer to that reference. + +We start by inserting a linear potential into Eq. 3.14: + +.. math:: + -\left(\nabla_\perp^2 + x\cdot u\right)\phi(x) = 2ik\frac{\partial}{\partial z} \phi(x) + :label: helmholtz + +where $\nabla_\perp^2$ is the 2D Laplacian operator in the XY plane and $k$ is the wavenumber of the incident electron. + + +This equation is an approximation of the Schrödinger equation (Helmholtz equation 3.3) that comes from the assuming the condition in 3.13, namely that $z$ derivatives of $\phi$ are much larger in magnitude than $z$ second derivatives, so that we can ignore second derivatives in $z$. + +Solution of the Helmholtz equation +================================== + +We propose a solution of the following form + +.. math:: + \phi(x) = \exp\left(i\left(k'+\frac{z}{2k}u'\right)\cdot x'\right)\exp\left(i\left(k_z' + \frac{z}{4k}u_z\right)z\right) + +where we have introduced the notation $u'=(u_x,u_y)$ and $k'=(k_x', k_y')$ along with $k_z'$ are (free, so far) parameters in our solution. + + +We now insert this form for $\phi$ into :eq:`helmholtz` to verify that it is a solution. Refer to page 18 of :cite:t:`narangifard2013` to confirm that + +.. math:: + -\nabla_\perp^2 \phi(x) = \left|k' + \frac{z}{2k}u'\right|^2 \phi(x). + +The $z$ derivative is also readily computed: + +.. math:: + \frac{\partial}{\partial z}\phi(x) = i\left(\frac{1}{2k}u'\cdot x' + \left(k_z'+\frac{z}{2k} u_z\right)\right)\phi(x). + +which simplifies to + +.. math:: + \frac{\partial}{\partial z}\phi(x) = i\left(\frac{1}{2k}u\cdot x +k_z'\right)\phi(x). + +Inserting these into :eq:`helmholtz` we have the condition + +.. math:: + \left|k' + \frac{z}{2k}u'\right|^2 - x\cdot u = -2k\left(\frac{1}{2k}u\cdot x + k_z'\right), + +which after cancellation becomes + +.. math:: + \left|k' + \frac{z}{2k}u'\right|^2 = -2k k_z' \iff k_z' = \frac{-\left|k' + \frac{z}{2k}u'\right|^2}{2k}. + :label: kz + +Under the condition in :eq:`kz`, our proposal solution does indeed solve our Helmholtz-type equation. + +Note that :cite:t:`narangifard2013` mentions the `Sommerfeld radiation condition <https://en.wikipedia.org/wiki/Sommerfeld_radiation_condition>`_ which formulates mathematically the condition that energy must emit from sources to infinity, instead of absorbing. This is needed in order to show that the solution is the unique one that is physical. + +.. math:: + \lim_{|x|\to\infty}|x|\left(\frac{\partial}{\partial |x|}-ik\right) u(x) = 0 + +with the convergence uniform in all directions approaching the origin. + +TODO: apply Sommerfeld to check feasibility of our solution + +Validity of multislice approximation +------------------------------------ + +As Narangifard notes, the assumption under which we can ignore the second $z$ derivative is Eq. 3.13: + +.. math:: + \left|2ik\frac{\partial}{\partial z}\phi(x)\right| \gg \left|\frac{\partial^2}{\partial z^2}\phi(x)\right|. + +The left-hand side is + +.. math:: + \left|2ik\frac{\partial}{\partial z}\phi(x)\right| = \left|u\cdot x + 2k k_z'\right| |\phi(x)|. + +The right-hand side is + +.. math:: + \frac{\partial^2}{\partial z^2}\phi(x) = \left(i\frac{1}{2k}u_z - \left(\frac{1}{2k}u\cdot x + k_z'\right)^2\right)\phi(x). + +The magnitude of this quantity is + +.. math:: + \left|\frac{\partial^2}{\partial z^2}\phi(x)\right| = \sqrt{\frac{u_z^2}{4k^2} + \left(\frac{1}{2k}u\cdot x + k_z'\right)^4}|\phi(x)|. + +The condition under which the multislice method is valid in the presence of external bias in direction $u$ is then + +.. math:: + \left|u\cdot x + 2k k_z'\right|^2 \gg \frac{u_z^2}{4k^2} + \left|\frac{1}{2k}u\cdot x + k_z'\right|^4 + +TODO: what is the interpretation of this condition? In the case of $u=0$, we can plug in the form of $k_z'$ derived above to see that we must have $4k^2 \gg \|k'\|^2$. What about in our case? + +Deriving the propagator +======================= diff --git a/examples/lightning.py b/examples/lightning.py index b9c2df7a19b7141d9ad57f38021020e2d7f7c987..b88f9cd195c1aa318c428425f02bc0713e16c1ab 100644 --- a/examples/lightning.py +++ b/examples/lightning.py @@ -12,14 +12,35 @@ import ptychopath as pt import ptychopath.data as ptd from ptychopath.data.datamodule import STEMDataModule +import os +import logging + class PlotParams(pl.Callback): - def __init__(self, every_n_train_steps): + def __init__(self, every_n_train_steps=None, every_n_train_epochs=None): super().__init__() self.every_n_train_steps = every_n_train_steps + self.every_n_train_epochs = every_n_train_epochs self.step = 0 + self.epoch = 0 def on_train_epoch_start(self, trainer, pl_module): + if ( + self.every_n_train_epochs is not None + and self.epoch % self.every_n_train_epochs == 0 + ): + self.save_plots(trainer, pl_module) + self.epoch += 1 + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, unused=0): + if ( + self.every_n_train_steps is not None + and self.step % self.every_n_train_steps == 0 + ): + self.save_plots(trainer, pl_module) + self.step += 1 + + def save_plots(self, trainer, pl_module): dz, dy, dx = pl_module.potential.voxel_spacing V = pl_module.potential() # plot the potential in central slices and log to tensorboard @@ -57,10 +78,14 @@ class PlotParams(pl.Callback): fig, global_step=self.step, ) + plt.close(fig) plot_potential_save(V, 0, "potential") plot_potential_save(V, 1, "potential") plot_potential_save(V, 2, "potential") + # 2D FFT + FV = torch.fft.fftshift(torch.fft.fft2(V).abs()) + plot_potential_save(FV, 0, "potential_fft") phase = torch.remainder( V * pl_module.multislice.interaction_constant, 2 * np.pi, @@ -115,26 +140,28 @@ class PlotParams(pl.Callback): fig, global_step=self.step, ) - fig = plt.figure() - if isinstance(pl_module.ccd_psf, nn.Module): - psf = pl_module.ccd_psf.weight.data - else: - psf = pl_module.ccd_psf_weight - plt.imshow( - psf.detach().squeeze().cpu().numpy(), - interpolation="nearest", - ) - plt.xlabel("X (voxels)") - plt.ylabel("Y (voxels)") - plt.colorbar() - pl_module.logger.experiment.add_figure( - "ccd_psf_weight", - fig, - global_step=self.step, - ) - plt.close() - plt.ion() - self.step += 1 + if pl_module.hparams.convolve_cbed: + fig = plt.figure() + if isinstance(pl_module.ccd_psf, nn.Module): + psf = pl_module.ccd_psf.weight.data + else: + psf = pl_module.ccd_psf_weight + plt.imshow( + psf.detach().squeeze().cpu().numpy(), + interpolation="nearest", + ) + plt.xlabel("X (voxels)") + plt.ylabel("Y (voxels)") + plt.colorbar() + pl_module.logger.experiment.add_figure( + "ccd_psf_weight", + fig, + global_step=self.step, + ) + plt.close() + plt.ion() + + plt.close("all") class PtychoLightning(pl.LightningModule): @@ -157,18 +184,89 @@ class PtychoLightning(pl.LightningModule): bw_limit=2 / 3.0, ) - # self.ccd_psf is optionally applied in exit wave reciprocal space - # before interpolation happens which rescales to CCD (CBED) space. - h, w = 9, 9 - y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) - sig_psf = 3.0 - self.ccd_psf = nn.Conv2d(1, 1, 9, bias=False, padding=4) - self.ccd_psf.weight.data = torch.exp( - -(((y - (h - 1) / 2) / sig_psf) ** 2 + ((x - (w - 1) / 2) / sig_psf) ** 2) - ).view(1, 1, h, w) + if self.hparams.convolve_cbed: + # self.ccd_psf is optionally applied in exit wave reciprocal space + # before interpolation happens which rescales to CCD (CBED) space. + h, w = 9, 9 + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) + sig_psf = 3.0 + self.ccd_psf = nn.Conv2d(1, 1, 9, bias=False, padding=4) + self.ccd_psf.weight.data = torch.exp( + -( + ((y - (h - 1) / 2) / sig_psf) ** 2 + + ((x - (w - 1) / 2) / sig_psf) ** 2 + ) + ).view(1, 1, h, w) + + self.register_buffer( + "cbed_resampling_grid", + self.microscope.cbed_resampling_grid( + (1, 1, *self.potential.voxels), + self.potential.voxel_spacing[-2:], + ), + ) + + if self.hparams.initialize_with_dpc: + dpc = torch.load(os.path.join(self.hparams.root, "dpc.pth")) + assert isinstance( + self.potential, pt.potential.DirectPotential + ), "`initialize_with_dpc` requires direct potential parametrization" + # corrected phase is in units of mrad + corphase = dpc["corrected_phase"] + dpc_offset = torch.as_tensor(dpc["offset"]) + dpc_spacing = torch.as_tensor(dpc["spacing"]) + # phase is computed on a shifted grid, since we use forward + # differences of the COM vector field. We must correct that here. + phase_offset = dpc_offset + 0.5 * dpc_spacing + # get the input center point, relative to the middle point, in voxels + midpt_pix = 0.5 * (torch.as_tensor(corphase.shape[-2:]) - 1) + midpt_world = phase_offset + dpc_spacing * midpt_pix + # origin (optic axis) location in pixel coordinates relative to + # center of phase image + center_pix = -midpt_world / dpc_spacing + + # subtract minimum so that we have all positive phase + # This is mostly inconsequential, but since we initialize to zero + # outside the FOV it has an affect on the appearance of the + # converged potential image + + # regrid to account for differences in offset and spacing + phase = pt.regrid.similarity2d( + corphase, + input_spacing=dpc_spacing, + input_center=center_pix, + output_center=(0, 0), + rotation=0.0, + output_spacing=self.potential.voxel_spacing[-2:], + output_shape=self.potential.voxels[-2:], + padding_mode="reflection", # zeros, reflection, or border + ).unsqueeze(0) + + # convert from phase (mrad) to potential (Volts) + from pyms.structure_routines import interaction_constant + + sigma = interaction_constant(1e3 * microscope.electron_energy) + thickness = self.potential.extent[-3] + print( + f"Dividing DPC phase (mrad) by thickness ({thickness} Angstrom) " + f"times sigma ({sigma} rad / (V Angstrom)) " + f"times 1e3 mrad / rad = {1e3 * sigma * thickness}" + ) + self.potential.V.data[...] = phase / (1e3 * sigma * thickness) @staticmethod def add_model_specific_args(parser): + parser.add_argument( + "--sqrt_intensity", + action="store_true", + help="Take square root of CBED and simulated intensity", + ) + parser.add_argument( + "--gain_estimation", + default="gain", + choices=("none", "gain", "gain_and_offset"), + help="For rescaling CBED, use closed-form optimal gain and offset?", + ) parser.add_argument( "--optimizer", default="SGD", @@ -211,9 +309,14 @@ class PtychoLightning(pl.LightningModule): action="store_true", help="Convolve by Gaussian point-spread function before interpolating CBED", ) + parser.add_argument( + "--initialize_with_dpc", + action="store_true", + help="If true, initialize potential with DPC, i.e. integrate the corrected center of mass image", + ) return parser - def training_step(self, batch, idx): + def training_step(self, batch, idx, log=True): cbed, probepos = batch # assume bs=1 cbed = cbed[0, 0] @@ -227,31 +330,73 @@ class PtychoLightning(pl.LightningModule): probepos_vox = self.potential.world2index(probepos) sim = self.multislice(self.probe_f, V, probe_shift=probepos_vox) - simcbed = torch.fft.fftshift(sim.abs(), dim=[-2, -1]) + simintensity = torch.fft.fftshift(sim.abs() ** 2, dim=[-2, -1]) # Resample to match CBED resolution and shift # Convolve by the detector's PSF before rescaling to prevent aliasing if self.hparams.convolve_cbed: - simcbed = self.ccd_psf(simcbed.view(1, 1, *simcbed.shape)).squeeze(0) + simintensity = self.ccd_psf( + simintensity.view(1, 1, *simintensity.shape) + ).squeeze(0) else: - simcbed = simcbed.unsqueeze(0) - simcbed = self.microscope.resample_cbed(simcbed, self.hparams.voxel_spacing) + simintensity = simintensity.unsqueeze(0) + simcbed = pt.regrid.grid_sample_complex( + simintensity.unsqueeze(0), + self.cbed_resampling_grid, + ).squeeze() - cbed = cbed.type(simcbed.dtype).sqrt() + cbed = cbed.type(simcbed.dtype) simcbed = simcbed.squeeze() - # optimal gain computation is straightforward - gain = (simcbed * cbed).sum() / (simcbed ** 2).sum() - simcbed = simcbed * gain + if self.hparams.sqrt_intensity: + cbed = cbed.clamp(0, None).sqrt() + simcbed = simcbed.sqrt() + + # Compute optimal gain (and offset) in closed form + if self.hparams.gain_estimation == "gain_and_offset": + # solve for gain and offset + # That is, solve + # min_{g,o} |g * X + o - Y|^2 + # where X=simcbed and Y=cbed. + # The optimality conditions are given by the following 2x2 system: + # sum(X) * g + n * o = sum(Y) + # X^TX * g + sum(X) * o = X^TY + # A direct inverse exists under the condition + # sum(X)^2 != n X^TX = n sum_i(X_i^2) + # + l2X = (simcbed ** 2).mean() + XdotY = (simcbed * cbed).mean() + a = simcbed.mean() + d = a + b = 1 + c = l2X + t = cbed.mean() + u = XdotY + + det = a * d - b * c + gain = (d * t - b * u) / det + offset = (a * u - c * t) / det + + simcbed = simcbed * gain + offset + elif self.hparams.gain_estimation == "gain": + gain = (simcbed * cbed).sum() / (simcbed ** 2).sum() + simcbed = simcbed * gain + elif self.hparams.gain_estimation == "none": + pass + else: + raise RuntimeError( + "Invalid choice for `gain_estimation`: " + self.hparams.gain_estimation + ) loss = F.mse_loss(simcbed, cbed) - self.log("train_mse", loss.item()) + + if log: + self.log("train_mse", loss.item()) return loss def configure_optimizers(self): from itertools import chain - # params = [self.V] params = self.potential.parameters() if self.hparams.optimizer == "SGD": @@ -293,7 +438,7 @@ if __name__ == "__main__": parser = pl.Trainer.add_argparse_args(parser) parser.add_argument( "--plot_every_n_steps", - default=10, + default=50, type=int, help="Plot potential every N iterations", ) @@ -305,7 +450,7 @@ if __name__ == "__main__": ) parser.add_argument( "--checkpoint_every_n_epochs", - default=1, + default=None, type=int, help="Save all training state every N epochs", ) @@ -336,10 +481,21 @@ if __name__ == "__main__": datamodule.setup(stage="fit") print("Microscope:", datamodule.microscope) + if args.checkpoint_every_n_steps is None and args.checkpoint_every_n_epochs is None: + logging.warning( + "Both --checkpoint_every_n_steps and " + "--checkpoint_every_n_epochs were omitted, which means " + "checkpointing is disabled altogether. Normally one or the other " + "of these (but not both) is desired." + ) + trainer = pl.Trainer.from_argparse_args( args, callbacks=[ - PlotParams(every_n_train_steps=args.plot_every_n_steps), + PlotParams( + every_n_train_steps=args.plot_every_n_steps, + every_n_train_epochs=1, + ), pl.callbacks.ModelCheckpoint( verbose=True, every_n_train_steps=args.checkpoint_every_n_steps, @@ -358,9 +514,10 @@ if __name__ == "__main__": trainer.test(mod, datamodule=datamodule) sys.exit(0) + datamodule.prepare_data() + datamodule.setup(stage="fit") + mod = PtychoLightning(datamodule.microscope, potential, **vars(args)) print("Trainable parameters:", [k for k, _ in mod.named_parameters()]) - datamodule.prepare_data() - datamodule.setup(stage="fit") trainer.fit(mod, train_dataloaders=datamodule.train_dataloader()) diff --git a/notebooks/FitAperture.ipynb b/notebooks/FitAperture.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..04178624f495f9fbbeff64b8a62d5cf494451f7d --- /dev/null +++ b/notebooks/FitAperture.ipynb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc400ec41d412aa6c4d126e77e0679c1b2a2b5d88495eb286b9357ac4326d640 +size 6802820 diff --git a/notebooks/Wigner.ipynb b/notebooks/Wigner.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f94df6b24e8b0e78c5af7f23dde0bed2bd12e230 --- /dev/null +++ b/notebooks/Wigner.ipynb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ee3c4629b23048e5509c4dd434068ea6f1fff48416bbd377649a2ad518477e6 +size 13316 diff --git a/notebooks/ptosizes.ipynb b/notebooks/ptosizes.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..9eb820d8aeada865d24984de45851c75fe46cf22 --- /dev/null +++ b/notebooks/ptosizes.ipynb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a2679b393d7097d609d06601d49d953fb7551b301aeba2fdbb7efe5d0b56be2 +size 395570 diff --git a/ptychopath/__init__.py b/ptychopath/__init__.py index 36027e5128c6d5f7f29274eeda6ca718c3e57434..cedc6aee42f46f05ec5c712fd030f95304fec3e8 100644 --- a/ptychopath/__init__.py +++ b/ptychopath/__init__.py @@ -19,3 +19,4 @@ from . import fresnel from . import microscope from . import multislice from . import potential +from . import regrid diff --git a/ptychopath/data/__init__.py b/ptychopath/data/__init__.py index 8665385f7b5847ed899a383ec8c3b10f209cf07f..277799c8f6e8e2621b9b72aaf3f6fb3a16ccd0be 100644 --- a/ptychopath/data/__init__.py +++ b/ptychopath/data/__init__.py @@ -1,5 +1,6 @@ """Datasets for ptychography experiments.""" +from .nion import * from .ptosim import * from .random import * @@ -7,6 +8,7 @@ from .random import * all_datamodules = { "ptosim": PTODataModule, "random": RandomSTEMDataModule, + "nion": NionDataModule, } diff --git a/ptychopath/data/nion.py b/ptychopath/data/nion.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab16ae14e187d7ec2f104dc23b16e6859776eda --- /dev/null +++ b/ptychopath/data/nion.py @@ -0,0 +1,446 @@ +"""STEMDataModules for data from the Nion instrument at CNMS.""" +import argparse +import h5py +import numpy as np +import pandas as pd +import pytorch_lightning as pl +import torch +from tqdm import tqdm + +import json +import math +import os +from pathlib import Path + +from .datamodule import STEMDataModule, STEMDataset +from ..microscope import calibration +from ..microscope.ccd import CCD +from ..microscope.conversions import electron_wavelength +from ..microscope.aberrations import fifth_order_aberrations, Aberrations +from ..microscope.microscope import Microscope + +__all__ = ["NionDataModule"] + + +def read_microscope_nion(d): + """Read microscope from Nion dict format.""" + # First verify this is a 4D-STEM dataset + cbed_dim = d["collection_dimension_count"] + scan_dim = d["datum_dimension_count"] + assert ( + scan_dim == 2 + ), f"Assumed 2D scan pattern, found datum_dimension_count={scan_dim}" + assert ( + cbed_dim == 2 + ), f"Assumed 2D CBED patterns, found collection_dimension_count={cbed_dim}" + + # A description of these aberrations and more can be found at + # https://www.globalsino.com/EM/page3740.html + ab = {} + foa = fifth_order_aberrations + a = d["metadata"]["instrument"]["ImageScanned"] + # NOTE: Nion writes lengths in meters unless specified, so we convert to Angstrom + for k in foa.index: + row = foa.loc[k] + kriv = f"C{row.n}{row.m}" + if row.m == 0: # C10, C30, C50 are real + x = 1e10 * float(a.get(kriv, 0)) + else: # complex, stored as .a and .b (real/imag) + x = { + "real": 1e10 * a.get(kriv + ".a", 0), + "imag": 1e10 * a.get(kriv + ".b", 0), + } + ab[k] = x + print(f"Aberrations: {ab}") + ab = Aberrations.from_dict(ab) + + # convert voltage (high_tension) to keV + ev = float(d["metadata"]["hardware_source"]["high_tension"]) * 1e-3 + + # read Detector info + cal = d["spatial_calibrations"] + assert cal[2]["units"] == "rad" + pixel_spacing = (1e3 * cal[2]["scale"], 1e3 * cal[3]["scale"]) + fov_mrad = (-2e3 * cal[2]["offset"], -2e3 * cal[3]["offset"]) + pixels = ( + int(fov_mrad[0] / pixel_spacing[0]), + int(fov_mrad[1] / pixel_spacing[1]), + ) + print("E (keV):", ev) + print("FOV (mrad):", fov_mrad) + print("spacing (mrad):", pixel_spacing) + ccd = CCD( + pixels, + pixel_spacing, + (0.0, 0.0), + None, # NOTE: this should be measured + ) + + return Microscope( + ev, + None, # NOTE: this should be measured + ab, + ccd, + ) + + +class NionDataModule(STEMDataModule): + + poc = "Ondrej Dyck <dyckoe@ornl.gov>, Stephen Jesse <sjesse@ornl.gov>, Debangshu Mukherjee <mukherjeed@ornl.gov>" + + def __init__( + self, + batch_size, + num_workers, + nion_json, + root, + y_range=None, + x_range=None, + force_defocus=None, + force_rotation=None, + additional_rotation_degrees=0.0, + wiener_reg=1e-2, + flip_x=False, + aperture_threshold_level=0.3, + ): + """ + Data module for Nion data. + + Nion provides data in paired .json and .npy files. This datamodule reads + those in and prepares them in our standardized directory containing JSON + files. + + Args: + nion_json: str: + Location of the .json file + root: str: + Location to write preprocessed directory + y_range: Tuple[int]: + Limit to only these scanlines of probe positions. If given (a, + b), use scanlines a through b-1. + x_range: Tuple[int]: + Limit each probe position scan line to these positions. + force_defocus: float: + If given, override defocus found in .json file with this value + (in Angstrom). + force_rotation: float: + If given, override estimated CBED rotation by this angle, in + radians. + additional_rotation_degrees: float: + Add this amount (in degrees) to cbed rotation. Commonly used to + add 180 degrees due to DPC rotation calibration ambiguity. + wiener_reg: float: + Amount of regularization in Wiener filter for initial DPC phase + estimate. + flip_x: bool: + Flip each CBED pattern in the X direction, to compensate for + occasional readout irregularity. + aperture_threshold_level: float: + Level, relative to maximum of average CBED at which to threshold + average CBED for fitting a circle in aperture calibration. + """ + super().__init__(root, batch_size, num_workers) + self.save_hyperparameters() + + def dependent_hparams(self): + """Return only the hparams that affect caching into .pth files. + + .. note:: + If :meth:`prepare` or :meth:`add_argparse_args` are altered, check + that any hparams introduced are represented here in order to prevent + cache corruption. + """ + hp = {} + for k in [ + "aperture_threshold_level", + "y_range", + "x_range", + "force_defocus", + "force_rotation", + "additional_rotation_degrees", + "flip_x", + "aperture_threshold_level", + ]: + hp[k] = self.hparams[k] + # add the directory name of the unprocessed_dir. This will help prevent + # mixups between neutral and ionic + hp["json_basename"] = os.path.basename(self.hparams.nion_json) + return hp + + def prepare_data(self): + """Collect raw simulation data into torch-ready CSV+PTH directory.""" + root = Path(self.hparams.root) + if os.path.exists(root / "index.csv"): + # check hparams for inconsistencies + hpstored = json.load(open(root / "hparams.json")) + hp = self.dependent_hparams() + if hpstored != hp: + raise RuntimeError( + f"Existing root directory {root} was created with " + f"hyperparameters {hpstored} which does not match {hp}. " + "Refusing to overwrite." + ) + return # root directory already prepared + imdir = root / "cbeds" + os.makedirs(imdir, exist_ok=True) + + nion_dirname = os.path.dirname(os.path.abspath(self.hparams.nion_json)) + nion_basename = os.path.basename(self.hparams.nion_json) + base, ext = os.path.splitext(nion_basename) + npyfile = os.path.join(nion_dirname, base + ".npy") + + d = json.load(open(self.hparams.nion_json)) + from pprint import pprint + + pprint(d) + + microscope = read_microscope_nion(d) + + if self.hparams.force_defocus is not None: + microscope.aberrations.C1 = self.hparams.force_defocus + + # Load memory-mapped CBEDs + cbeds = np.load(npyfile, mmap_mode="r") + + assert list(cbeds.shape[:2]) == d["metadata"]["scan"]["scan_size"] + if self.hparams.y_range is not None: + # restrict to a subset of scanlines if requested + start, end = self.hparams.y_range + cbeds = cbeds[start:end] + if self.hparams.x_range is not None: + # restrict to a subset of scanlines if requested + start, end = self.hparams.x_range + cbeds = cbeds[:, start:end] + + if self.hparams.flip_x: + print("Flipping CBEDs in X direction") + cbeds = cbeds[..., ::-1] + + # We will pass over all the CBED data to convert to .pth + # While we do so, we will also compute the average CBED using a running + # average, as well as the COM of each CBED. + # After we're done looping, we'll use the average CBED and the COM + # vector field to calibrate the aperture and detector rotation. + N = cbeds.shape[0] * cbeds.shape[1] + cbed_av = torch.zeros(cbeds.shape[-2:], dtype=torch.float64) + com_pixels = torch.zeros(2, *cbeds.shape[:2]) + + # Use numpy which has a stable api for meshgrid unlike current pytorch + ys, xs = np.meshgrid( + np.arange(cbeds.shape[-2]), + np.arange(cbeds.shape[-1]), + indexing="ij", + ) + ys = torch.as_tensor(ys) + xs = torch.as_tensor(xs) + + filenames = [] + positions = [] + cal = d["spatial_calibrations"] + assert cal[0]["units"] == "nm" + # convert probe grid from nm to angstrom + probe_start = (cal[0]["offset"] * 10.0, cal[1]["offset"] * 10.0) + probe_step = (cal[0]["scale"] * 10.0, cal[1]["scale"] * 10.0) + + for y in tqdm(range(cbeds.shape[0]), desc="Preprocessing CBED lines"): + realy = y if self.hparams.y_range is None else y + self.hparams.y_range[0] + probe_y_ang = probe_start[0] + realy * probe_step[0] + for x in range(cbeds.shape[1]): + realx = ( + x if self.hparams.x_range is None else x + self.hparams.x_range[0] + ) + probe_x_ang = probe_start[1] + realx * probe_step[1] + positions.append((probe_y_ang, probe_x_ang)) + filename = f"cbed_{y}_{x}.pth" + c = torch.tensor(cbeds[y, x, :, :].copy()) + + # compute average CBED and COM on the fly + cbed_av += c.type(torch.float64) / N + cnormed = c / c.mean() + + # Note ys and xs are in _pixel_ units, so this computes COM in + # pixels + com_pixels[0, y, x] = (cnormed * ys).mean() + com_pixels[1, y, x] = (cnormed * xs).mean() + + c = c.unsqueeze(0) # store as CHW, with C=1 + filenames.append(filename) + torch.save(c, imdir / filename) + + # save average CBED + torch.save(cbed_av, root / "average_cbed.pth") + + # estimate aperture from a thresholded average CBED + aperture, beam_axis = calibration.aperture.calibrate_aperture( + cbed_av, + pixel_spacing=microscope.ccd.pixel_spacing, + ys=ys, + xs=xs, + threshold=self.hparams.aperture_threshold_level, + ) + microscope.aperture = aperture + microscope.ccd.beam_axis = beam_axis + print(f"Calibrated beam center is {microscope.ccd.beam_axis} pixels.") + print(f"Calibrated aperture is {microscope.aperture} mrad.") + + # Subtract from com_pixels so COM is relative to optic axis + com_pixels -= ( + torch.tensor(beam_axis) + (torch.tensor(cbeds.shape[:2]) - 1) / 2 + ).view(2, 1, 1) + # multiply by ccd pixel scale to get mrad + com_mrad = com_pixels * torch.tensor(microscope.ccd.pixel_spacing).view(2, 1, 1) + + if self.hparams.force_rotation is not None: + microscope.ccd.rotation = self.hparams.force_rotation + else: + # estimate rotation from DPC using COM vector field + rot = calibration.dpc.calibrate_rotation(com_mrad) + print( + f"Calibrated detector rotation is {rot} radians" + f" {rot * 180. / np.pi} degrees" + ) + rot += (np.pi / 180.0) * self.hparams.additional_rotation_degrees + microscope.ccd.rotation = rot + print( + f"Further rotated detector rotation to {microscope.ccd.rotation} radians" + f" {microscope.ccd.rotation * 180. / np.pi} degrees" + ) + + # Correct COM by applying theta in the CLOCKWISE direction + corcom = com_mrad.clone() + c = np.cos(microscope.ccd.rotation) + s = np.sin(microscope.ccd.rotation) + corcom[0] = c * com_mrad[0] + s * com_mrad[1] + corcom[1] = -s * com_mrad[0] + c * com_mrad[1] + + # Compute potential using DCT approach, from COM vector field + corphase = calibration.dpc.com2phase( + corcom, spacing=probe_step, reg=self.hparams.wiener_reg + ) + + # save offset and spacing of probes + comdict = { + "corrected_phase": corphase, + "corrected_com": corcom, + "uncorrected_com": com_mrad, + "offset": ( + probe_start[0] + + probe_step[0] + * (0 if self.hparams.y_range is None else self.hparams.y_range[0]), + probe_start[1] + + probe_step[1] + * (0 if self.hparams.x_range is None else self.hparams.x_range[0]), + ), + "spacing": probe_step, + } + torch.save( + comdict, + root / "dpc.pth", + ) + + # save hparams (except those that may change) in order to check for + # inconsistencies + json.dump( + self.dependent_hparams(), + open(root / "hparams.json", "w"), + sort_keys=True, + indent=2, + ) + + pprint(microscope.to_dict()) + json.dump( + microscope.to_dict(), + open(root / "microscope.json", "w"), + sort_keys=True, + indent=2, + ) + + # write index file as last step, since this is how we check whether the + # directory is already prepared or not + df = pd.DataFrame( + { + "filename": filenames, + "probe_x_ang": [x for x, _ in positions], + "probe_y_ang": [y for _, y in positions], + } + ) + df.to_csv(root / "index.csv") + + @staticmethod + def add_argparse_args(parser): + parser = STEMDataModule.add_argparse_args(parser) + parser.add_argument( + "--nion_json", + type=str, + help="Location of original Nion data json file (.npy filename determined from this)", + ) + parser.add_argument( + "--y_range", + type=int, + nargs=2, + default=None, + help="If provided as two integers a and b, limit to only scanlines from a thru b-1.", + ) + parser.add_argument( + "--x_range", + type=int, + nargs=2, + default=None, + help="If provided as two integers a and b, limit each scanline only pixels from a thru b-1.", + ) + parser.add_argument( + "--force_defocus", + type=float, + default=None, + help="If provided, manually set this defocus value in Angstrom.", + ) + parser.add_argument( + "--flip_x", + action="store_true", + help="Flip CBEDs in X direction.", + ) + parser.add_argument( + "--force_rotation", + type=float, + default=None, + help="If given, use this rotation (in radians) instead of estimating it using DPC.", + ) + parser.add_argument( + "--additional_rotation_degrees", + type=float, + default=0.0, + help="Amount of additional rotation to apply to CBEDs. Unused if --force_rotation is provided.", + ) + parser.add_argument( + "--wiener_reg", + type=float, + default=1e-2, + help="Amount of regularization for DPC Wiener deconvolution.", + ) + parser.add_argument( + "--aperture_threshold_level", + type=float, + default=0.3, + help="Float between 0 and 1 indicating the level relative to the " + "maximum, that we threshold the average CBED pattern at to fit a " + "circle and determine the condensing aperture.", + ) + return parser + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser = NionDataModule.add_argparse_args(parser) + args = parser.parse_args() + + with torch.no_grad(): + dm = NionDataModule.from_argparse_args(args) + dm.prepare_data() + dm.setup() + dl = dm.train_dataloader() + cbeds, probepos = next(iter(dl)) + print("Batch size:", dm.hparams.batch_size) + print("CBED batch shape:", cbeds.shape) + print("Probe position batch shape:", probepos.shape) diff --git a/ptychopath/data/ptosim.py b/ptychopath/data/ptosim.py index 676be404b9d392663123628815fbe38252eb8a97..9ea06e3a5fec7ba2485f52c93b37296cb9388b3f 100644 --- a/ptychopath/data/ptosim.py +++ b/ptychopath/data/ptosim.py @@ -56,15 +56,15 @@ def read_microscope_h5(h): # use it as the CCD pixel spacing resulting in an identity transform between # psi_f and CCD. unitcell = h["Crystal"].attrs["a0"] - fovx = unitcell[0] * c["Nxtile"] fovy = unitcell[1] * c["Nytile"] + fovx = unitcell[0] * c["Nxtile"] ev = float(a["Energy"]) lam = float(electron_wavelength(ev)) - spacing = (fovx / lam, fovy / lam) + spacing = (1e3 * lam / fovy, 1e3 * lam / fovx) print("Aperture:", a["Aperture"]) print("Tiling:", c["Nxtile"], c["Nytile"]) print("unitcell:", unitcell) - print("FOV (ang):", fovx, fovy) + print("FOV (ang):", fovy, fovx) print("E (keV):", ev) print("wavelength (ang):", lam) print("spacing (mrad):", spacing) @@ -72,6 +72,7 @@ def read_microscope_h5(h): pixels, spacing, (0.0, 0.0), + 0.0, ) return Microscope( diff --git a/ptychopath/data/random.py b/ptychopath/data/random.py index 95194012a6a6c95dbb0cb93d15997efddfdc9e46..d895e8d50b7d146dd2c26f202808d0f1eeaec831 100644 --- a/ptychopath/data/random.py +++ b/ptychopath/data/random.py @@ -86,6 +86,7 @@ class RandomSTEMDataModule(STEMDataModule): (self.hparams.ccd_size[0] - 1) / 2, (self.hparams.ccd_size[1] - 1) / 2, ), + "rotation": 0.0, }, "electron_energy_kev": 200.0, } diff --git a/ptychopath/microscope/calibration/__init__.py b/ptychopath/microscope/calibration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c3dc0810b94cc148a83e8fed373e9c73369a66a --- /dev/null +++ b/ptychopath/microscope/calibration/__init__.py @@ -0,0 +1,5 @@ +""" +Calibration of microscope parameters from CBED data. +""" + +from . import aperture, dpc diff --git a/ptychopath/microscope/calibration/aperture.py b/ptychopath/microscope/calibration/aperture.py new file mode 100644 index 0000000000000000000000000000000000000000..7b252707bcfcf4bd8b5d1b3b75c15c56b7da8d77 --- /dev/null +++ b/ptychopath/microscope/calibration/aperture.py @@ -0,0 +1,66 @@ +""" +Calibration of aperture using moments +""" + +import numpy as np +import torch + +__all__ = ["calibrate_aperture"] + + +def calibrate_aperture(cbed_avg, pixel_spacing, threshold, ys=None, xs=None): + """ + Calibrate both aperture and beam axis using average CBED. + + This thresholds the average CBED at `threshold * cbed_avg.max()` then + fits a circle. + + Args: + cbed_avg: torch.Tensor: + The average of all cbed patterns in a dataset, as a 2D tensor + with HW dimension ordering. + pixel_spacing: Tuple[float]: + Spacing in YX of the detector + threshold: float: + Portion of maximum at which to threshold average CBED for + determining circle. This should be around 0.2-0.7 and defaults + to 0.3. + ys: torch.Tensor: + Tensor holding y coordinates at each probe position, in pixels. Will + be computed in pixel units if omitted. + xs: torch.Tensor: + Tensor holding x coordinates at each probe position, in pixels. Will + be computed in pixel units if omitted. + Returns: + (float, (float, float)): + The estimated aperture radius, in milliradians, and the beam axis, + in YX ordering, in pixel units. + """ + # Estimate beam center from COM of thresholded avg CBED + apmask = (cbed_avg > cbed_avg.max() * threshold).type(torch.float64) + apmask /= apmask.mean() + if ys is None or xs is None: + # Use numpy which has a stable api for meshgrid unlike current pytorch + ys, xs = np.meshgrid( + np.arange(cbed_avg.shape[-2]), + np.arange(cbed_avg.shape[-1]), + indexing="ij", + ) + ys = torch.as_tensor(ys) + xs = torch.as_tensor(xs) + + meany = (ys * apmask).mean().item() + meanx = (xs * apmask).mean().item() + beam_axis = ( + meany - (cbed_avg.shape[-2] - 1) / 2, + meanx - (cbed_avg.shape[-1] - 1) / 2, + ) + # Get locations relative to center, in mrad + ymm = (ys - meany) * pixel_spacing[0] + xmm = (xs - meanx) * pixel_spacing[1] + secmom = ((ymm ** 2 + xmm ** 2) * apmask).mean() + # Integral of r^2 over a circle is pi R^4 / 2. + # We normalized by dividing by the area of the circle, pi R^2, + # so the second moment (secmom) should be R^2 / 2 + aperture = torch.sqrt(secmom * 2).item() + return aperture, beam_axis diff --git a/ptychopath/microscope/calibration/dpc.py b/ptychopath/microscope/calibration/dpc.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9f49a73f05502f6a010ddc578f41379c9ddab8 --- /dev/null +++ b/ptychopath/microscope/calibration/dpc.py @@ -0,0 +1,320 @@ +""" +Calibrations using differential phase contrast (DPC). +""" + +import numpy as np +import torch + +__all__ = ["calibrate_rotation", "com2phase", "invert_laplacian_dct_wiener"] + + +def dy(f, spacing): + fy = f[1:, :] - f[:-1, :] + return (fy[:, 1:] + fy[:, :-1]) / spacing + + +def dx(f, spacing): + fx = f[:, 1:] - f[:, :-1] + return (fx[1:, :] + fx[:-1, :]) / spacing + + +def div(v, spacing): + dvyy = dy(v[0], spacing[0]) + dvxx = dx(v[1], spacing[1]) + return dvyy + dvxx + + +def curl(v, spacing): + dvyx = dx(v[0], spacing[1]) + dvxy = dy(v[1], spacing[0]) + return dvyx - dvxy + + +def tile_mirrored(f): + # tile image with mirroring + H, W = f.shape + tiled = torch.zeros(2 * f.shape[0], 2 * f.shape[1]) + tiled[:H, :W] = f + tiled[H:, :W] = torch.flipud(f) + tiled[:H, W:] = torch.fliplr(f) + tiled[H:, W:] = torch.flipud(torch.fliplr(f)) + return tiled + + +def invert_laplacian_dct_wiener(f, spacing, reg=1e8): + """ + Invert the Laplacian operator with a simple Wiener filter. + + This tiles a mirrored version of the vector field `f` in order to operate on + the discrete cosine transform (DCT) instead of the Fourier transform, + avoiding wraparound artifacts during deconvolution[^1]. See + :cite:t:`ishizuka17` for more detail. + + [^1]: https://en.wikipedia.org/wiki/Wiener_deconvolution + + Args: + f: torch.Tensor: + Function to be deconvolved + spacing: Tuple[float]: + Grid spacing of `f` + reg: float: + Amount of regularization for Wiener deconvolution. + Returns: + torch.Tensor: + Deconvolved function. + """ + m = tile_mirrored(f) + HH, WW = m.shape + # Discrete second derivatives in each direction + # continuous + # d2y = -(((2 * np.pi) * torch.fft.fftfreq(HH).unsqueeze(1)) ** 2) + # d2x = -(((2 * np.pi) * torch.fft.rfftfreq(WW).unsqueeze(0)) ** 2) + # iterated central first derivs + # d2y = ( + # torch.sin((2 * np.pi) * torch.fft.fftfreq(HH).unsqueeze(1)) / + # spacing[0]) ** 2 + # d2x = ( + # torch.sin((2 * np.pi) * torch.fft.rfftfreq(WW).unsqueeze(0)) / + # spacing[1]) ** 2 + # discrete second deriv + d2y = ( + 2 * (torch.cos((2 * np.pi) * torch.fft.fftfreq(HH)) - 1).unsqueeze(1) + ) / spacing[0] ** 2 + d2x = ( + 2 * (torch.cos((2 * np.pi) * torch.fft.rfftfreq(WW)) - 1).unsqueeze(0) + ) / spacing[1] ** 2 + H = d2y + d2x + F = torch.fft.rfft2(m) + S = F.abs() ** 2 # estimate of mean power spectral density in the input + HS = H * S + # Wiener filter applies filter and adjoint (Laplacian is self-adjoint) in + # denominator + HHS = H * HS # HHS is non-negative + G = HS / (HHS + reg) # reg is spectral power density of noise + phase = torch.fft.irfft2(G * F, s=(HH, WW)) + return phase[: f.shape[0], : f.shape[1]] + + +def com2phase(com, spacing, reg=1e-2): + r""" + Create an estimated phase using Poisson's law. + + A focused electron passing through a thin sample sees a roughly constant + potential gradient (electric field). Under that approximation, the + Fourier transform of the wavefunction is shifted by an amount + $\frac{\tau\sigma\nabla_\perp V}{2\pi}$ where $\tau$ is the sample + thickness, $\sigma$ the interaction constant for the electron, and + $\nabla_\perp V$ is the planar component of the gradient potential (i.e. the + negative electric field in the material). This means once we estimate the + shift in the CBED patterns via estimation of a center-of-mass (COM) vector + field, we can use the following relation to solve for a 2D potential $V$: + + .. math:: + + \tau\sigma \Delta V = 2\pi \diverg \mathrm{COM} + + where the Laplacian, $\Delta$, must be inverted after computing the + divergence of the COM. The quantity $\tau\sigma V$ is the phase imparted + by the thin object, and is what we return in this function. + + Args: + com: torch.Tensor: + The center of mass vector field in milliradians, as a 2HW tensor. + spacing: Tuple[float]: + Beam spacing for COM grid + reg: float: + Amount of regularization to use when constructing the Wiener filter. + Returns: + torch.Tensor: + The estimated phase as an HW shaped tensor, in milliradians. + """ + dc = div(com, spacing) + phase = (2 * np.pi) * invert_laplacian_dct_wiener(dc, spacing, reg=reg) + return phase + + +def calibrate_rotation(com_pix): + r""" + Given center of mass vector field (in pixels), compute detector + rotation. + + We assume that the CBED patterns have been artificially rotated + counter-clockwise (CCW) by an angle . This results in a CCW rotation of the + COM by . We wish to estimate $\theta$ so that we can use it in + regrid.similarity2d() to resample exit waves with the proper clockwise + correction. + + This method finds the original CCW rotation that, when rotated clockwise + (i.e. by $-\theta$) minimizes the curl of the vector field, which is only + determined up to a rotation of 180 degrees. + + Recall the formulas for divergence and curl in 2D. Note that curl is a + scalar field in the two-dimensional case, corresponding to the Z + component of a 3D curl of two vector fields having zero Z component. + + .. math:: + + \begin{align} + \diverg v &= \frac{\partial v_x}{\partial x} + \frac{\partial + v_y}{\partial y} \\ + \curl v &= \frac{\partial v_y}{\partial x} - \frac{\partial + v_x}{\partial y}. + \end{align} + + We first compute the curl and divergence using simple forward + differences. Then we observe that a clockwise rotation by angle $\theta$ of + each vector in a vector field results in a curl of + + .. math:: + + \begin{align} + \curl (R_{-\theta} v) &= + \frac{\partial (-\sin\theta v_x + \cos\theta v_y)}{\partial x} - + \frac{\partial + (\cos\theta v_x + \sin\theta v_y)}{\partial y} \\ + &= \cos\theta \curl v - \sin\theta \diverg v. + \end{align} + + In fact, the divergence satisfies a similar equation: + + .. math:: + + \diverg (R_\theta v) = \sin\theta \curl v + \cos\theta \diverg v. + + This implies that rotations stay within the span of the curl + and divergence in function space. We would like to minimize the + following for $\theta$: + + .. math:: + + \begin{align} + \hat{\theta} &= \argmin_\theta \|\curl\left(R_{-\theta} v\right)\|^2 = \argmin_\theta \|\cos\theta\curl v + \sin\theta + \diverg v\|^2. + \end{align} + + Recall the following parametrization of a 2D ellipse in nD: given two + vectors, $\mathbf{u}, \mathbf{w}\in\R^n$, we form an $n\times 2$ matrix + with $\mathbf{u}$ and $\mathbf{w}$ as columns. Then the ellipse is a + curve $\gamma:\R\to\R^n$. + + .. math:: + + \gamma(\theta) = \left(\mathbf{u}, \mathbf{w} + \right)\left(\begin{array}{c} \cos\theta \\ + \sin\theta + \end{array}\right) = M \left(\begin{array}{c} + \cos\theta \\ + \sin\theta + \end{array}\right) + + Our problem is equivalent to $\mathbf{u}=\curl v$ and + $\mathbf{w}=\diverg v$, and we'd like to find the semi-minor axis of + this ellipse, i.e. the point of minimum norm on this curve. + + To get the semiminor axis, we can find the SVD of the matrix $M$ or, + equivalently, the eigendecomposition of either $MM^T$ or $M^TM$. $MM^T$ + is of size $n\times n$, but $M^TM$ is of size $2\times 2$, which is much + easier to work with. The entries of the smaller matrix are as follows: + + .. math:: + + M^TM = \left( + \begin{array}{cc} + a & b \\ + b & d + \end{array} + \right) = \left( + \begin{array}{cc} + \|u\|^2 & u^T w \\ + w^T u & \|w\|^2 + \end{array} + \right) + + The eigenvalues of this matrix are available in terms of its trace and + determinant: + + .. math:: + + \begin{align} + T &= \|u\|^2+\|w\|^2 \\ + D &= \|u\|^2\|w\|^2 - \left(u^Tw\right)^2 \\ + \lambda &= \frac{T \pm \sqrt{T^2-4D}}{2} + \end{align} + + The minor eigenvalue corresponds to a negative sign in the numerator. We + want a corresponding eigenvector which we'll call $e$. + + .. math:: + + \begin{align} + \left(M^TM - \lambda I\right) e &= 0 \\ + \left(\begin{array}{cc} + a - \lambda & b \\ + b & d - \lambda + \end{array}\right) + \left(\begin{array}{c} + e_1 \\ + e_2 + \end{array}\right) &= + \left(\begin{array}{c} + 0 \\ + 0 + \end{array}\right). + \end{align} + + If $b\ne 0$ then set $e_1=b$. Then we have + + .. math:: + + \begin{align} + ab - \lambda b + b e_2 &= 0 \\ + e_2 &= \lambda - a. + \end{align} + + If $b=0$, i.e. the divergence and curl are orthogonal, then $u$ and $w$ + already form an orthogonal basis, and $\|u\|^2$ and $\|w\|^2$ are the + corresponding eigenvalues. In that case, we know that if $\|u\|^2\le + \|w\|^2$ the optimal rotation is zero or 180 degrees, and otherwise it + is 90 or 270 degrees. We can check for this condition explicitly or as + we will see, use :func:`~torch.atan2` to avoid singularities. + + In the case that $b\ne 0$, we determine the angle that $e$ makes with + the first coordinate axis: + + .. math:: + + \begin{align} + \hat{\theta} = \atan\left(\frac{e_2}{e_1}\right) &= + \atan\left(\frac{\lambda - a}{b}\right) \\ &= + \atan\left(\frac{\frac{T \pm \sqrt{T^2-4D}}{2} - a}{b}\right) \\ + &= \atan\left(\frac{d - a \pm \sqrt{(a+d)^2-4(ad-b^2}}{2b}\right) \\ + &= \atan\left(\frac{\|w\|^2 - \|u\|^2 \pm + \sqrt{(\|u\|^2+\|w\|^2)^2 - 4\left(\|u\|^2\|w\|^2- + \left(u^Tw\right)^2\right)}}{2u^Tw}\right) + \end{align} + + This is value of $\theta$ returned by this function. + + Args: + com_pix: torch.Tensor: + Center of mass at each probe position in pixels, as a tensor of + shape 2HW. + Returns: + float: + An optimal rotation angle in radians, i.e. one that when applied in + clockwise direction minimizes the L2 norm of the curl of the + corrected COM vector field. Note that adding any integer multiple of + $\pi$ maintains optimality. + """ + # compute div and curl of v with simple finite differences (assumes equal XY spacing in probe locations) + divv = div(com_pix, (1, 1)) + curlv = curl(com_pix, (1, 1)) + # compute all pairwise inner products between div and curl + a = (curlv ** 2).mean() + d = (divv ** 2).mean() + b = (divv * curlv).mean() + # Now find theta + return torch.atan2( + d - a - torch.sqrt((a + d) ** 2 - 4 * (a * d - b * b)), + 2 * b, + ).item() diff --git a/ptychopath/microscope/ccd.py b/ptychopath/microscope/ccd.py index c196713b6b7fbfbed431b9d047edc36cc22cd593..561d91f1ea6d3cc32ce737d12cdfb5b169dd4e55 100644 --- a/ptychopath/microscope/ccd.py +++ b/ptychopath/microscope/ccd.py @@ -1,3 +1,5 @@ +import torch + from dataclasses import dataclass __all__ = ["CCD"] @@ -12,13 +14,16 @@ class CCD: pixel_spacing: (float, float) """Pixel spacing in the vertical/horizontal directions (mrad)""" beam_axis: (float, float) - """Beam axis location in pixels.""" + """Beam axis location in pixels""" + rotation: float + """Rotation angle (radians)""" def to_dict(self) -> dict: return { "pixels": self.pixels, "pixel_spacing": self.pixel_spacing, "beam_axis": self.beam_axis, + "rotation": self.rotation, } @classmethod @@ -27,4 +32,5 @@ class CCD: d["pixels"], d["pixel_spacing"], d["beam_axis"], + d["rotation"], ) diff --git a/ptychopath/microscope/microscope.py b/ptychopath/microscope/microscope.py index 3f4b3f180c59cc0443c0d59f9a3cec501b9e285c..09b0c4e6afb5ed56a4cd140a9abbf5a12b50136d 100644 --- a/ptychopath/microscope/microscope.py +++ b/ptychopath/microscope/microscope.py @@ -7,6 +7,7 @@ from torch.nn import functional as F from .aberrations import Aberrations from .ccd import CCD +from ..regrid import grid_sample_complex, similarity2d_grid, similarity2d from dataclasses import dataclass import logging @@ -52,7 +53,7 @@ class Microscope: def electron_wavelength(self) -> float: r"""De Broglie wavelength with relativistic correction in Angstrom. - Calls :func:~`ptychopath.conversions.electron_wavelength`. + Calls :func:`ptychopath.microscope.conversions.electron_wavelength`. """ from .conversions import electron_wavelength @@ -67,6 +68,43 @@ class Microscope: CCD.from_dict(d["ccd"]), ) + def cbed_resampling_grid(self, plane_shape, pixel_spacing): + lam = self.electron_wavelength() + # psi_f grid spacing in mrad + fov = torch.tensor( + [ + plane_shape[-2] * pixel_spacing[-2], + plane_shape[-1] * pixel_spacing[-1], + ] + ) + # compute pixel spacing of psi_f in milliradians + dtheta_psi_f = 1e3 * lam / fov + # scaling is ratio of dtheta from CCD and exit wave computation + # warn if this is too big or too small, indicating possible losses due + # to interpolation + scale = torch.as_tensor(self.ccd.pixel_spacing) / dtheta_psi_f + # scale *= torch.as_tensor(self.ccd.pixels) / torch.as_tensor(psi_f.shape[-2:]) + if not self._scale_warned and (min(scale) < 0.5 or max(scale) > 2.0): + logging.warning(f"CCD rescale beyond 2x encountered: {scale}") + self._scale_warned = True + + # center of input should be the fft-shifted 0, 0 pixel + # Generally this is _not_ the point (N - 1) / 2, but N // 2 + psi_center = ( + plane_shape[-2] // 2 - 0.5 * (plane_shape[-2] - 1), + plane_shape[-1] // 2 - 0.5 * (plane_shape[-1] - 1), + ) + + return similarity2d_grid( + plane_shape, + input_spacing=dtheta_psi_f, + output_spacing=self.ccd.pixel_spacing, + rotation=self.ccd.rotation, + input_center=psi_center, + output_center=self.ccd.beam_axis, + output_shape=self.ccd.pixels, + ) + def resample_cbed(self, psi_f, pixel_spacing): r"""Resample a given exit wavefunction (reciprocal) to form CCD pattern. @@ -92,15 +130,8 @@ class Microscope: \psi(x,y,z) = \mathcal{F}_{2D}\left[\psi(\cdot,\cdot, 0)\right]\left(\frac{x}{\lambda z}, \frac{y}{\lambda z}\right). - Here $\lambda$ is the de Broglie wavelength of the electron: - - .. math:: - - \begin{align} - m &= 9.10938356\times 10^{-31}\mbox{ kg} \\ - E &= \frac{p^2}{2m} \\ - \lambda &= \frac{h}{p} = \frac{h}{\sqrt{2m E}}. - \end{align} + Here $\lambda$ is the de Broglie wavelength of the electron, computed + using the electron energy in :meth:`electron_wavelength`. A CCD pixel located $(i,j)$ many pixels from the beam axis corresponds to the point $(x,y,z)=(i\Delta x_{CCD}, j\Delta y_{CCD}, z)$, which maps @@ -147,59 +178,19 @@ class Microscope: Args: psi_f: complex :class:`torch.Tensor`: The Fourier transform of the electron's wavefunction at the - bottom of the sample. + bottom of the sample, after fftshift. pixel_spacing: array_like with two float elements: The pixel spacing of the computation grid (not in reciprocal space) in units of Angstrom. Returns: - cbed: complex :class:`torch.Tensor`: + complex :class:`torch.Tensor`: The complex-valued wavefunction at the CCD panel, of same shape as the CCD readout data. """ - lam = self.electron_wavelength() - # psi_f grid spacing in mrad - fov = torch.tensor( - [ - psi_f.shape[-2] * pixel_spacing[-2], - psi_f.shape[-1] * pixel_spacing[-1], - ] + return grid_sample_complex( + psi_f, + self.cbed_resampling_grid(psi_f.shape, pixel_spacing), ) - # compute pixel spacing of psi_f in milliradians - dtheta_psi_f = fov / lam - # scaling is ratio of dtheta from CCD and exit wave computation - scale = dtheta_psi_f / torch.as_tensor(self.ccd.pixel_spacing) - # the convention in torch.affine_grid is to give values from -1 to 1 - # indicating the corners of the input image so when resizing to a - # different shape, we must use the following adjustment. - scale = scale * ( - torch.as_tensor(self.ccd.pixels) / torch.as_tensor(psi_f.shape[-2:]) - ) - if min(scale) < 0.5 or max(scale) > 2.0 and not self._scale_warned: - logging.warning(f"CCD rescale beyond 2x encountered: {scale}") - self._scale_warned = True - # Perform axial scaling and translation - g = F.affine_grid( - torch.tensor( - [ - [ - [scale[0], 0, -scale[0] * self.ccd.beam_axis[0]], - [0, scale[1], -scale[1] * self.ccd.beam_axis[1]], - ] - ] - ), - size=(1, 1, *self.ccd.pixels), - align_corners=False, - ).to(psi_f.device) - - psi_f = psi_f.unsqueeze(0) - - if psi_f.is_complex(): - # sample each component separately then recombine - cbed_r = F.grid_sample(psi_f.real, g, align_corners=False) - cbed_i = F.grid_sample(psi_f.imag, g, align_corners=False) - return torch.complex(cbed_r, cbed_i) - else: - return F.grid_sample(psi_f, g, align_corners=False) def probe_fourier(self, shape, grid_spacing): """ @@ -212,7 +203,7 @@ class Microscope: grid_spacing: 2D array_like of floats: Size of a pixel in Angstrom. Returns: - probe_f: complex torch.Tensor: + complex torch.Tensor: A 2D complex-valued tensor representing the Fourier transform of a probe centered at the (0, 0) pixel. """ diff --git a/ptychopath/multislice.py b/ptychopath/multislice.py index 2bb7250545c49ab578e151bb755f8feed283712f..f042178d8b563658ebfec60cd80a1bea07f88101 100644 --- a/ptychopath/multislice.py +++ b/ptychopath/multislice.py @@ -56,7 +56,9 @@ def multislice_direct(P, probe_f, V, interaction_constant, bw_mask=None): psi = torch.fft.ifft2(psi_f) psi = psi * T psi_f = torch.fft.fft2(psi) - psi_f = psi_f * P + if z < V.shape[-3] - 1: + # skip propagator on last slice + psi_f = psi_f * P return psi_f @@ -111,9 +113,12 @@ class Multislice(nn.Module): bandwidth_limit=bw_limit, ), ) - self.register_buffer( - "bw_mask", bandwidth_limiting_fourier_mask(plane_shape, bw_limit) - ) + if bw_limit is None: + self.bw_mask = None + else: + self.register_buffer( + "bw_mask", bandwidth_limiting_fourier_mask(plane_shape, bw_limit) + ) def forward(self, probe_f, V, probe_shift=None): """ diff --git a/ptychopath/potential/fdes.py b/ptychopath/potential/fdes.py index b92ecca85c541197ff2a5e04add048b812009f81..3cd6538370077ad1b0cc4078ee95939c4d7073a0 100644 --- a/ptychopath/potential/fdes.py +++ b/ptychopath/potential/fdes.py @@ -11,6 +11,7 @@ from .potential import Potential from . import kirkland import logging +import math class SplatPositionsFunction(torch.autograd.Function): @@ -227,10 +228,10 @@ class FDESPotential(Potential): w = torch.zeros((1, x.shape[-4] + 1, *self.coarse_voxels), dtype=torch.float32) # Start with mostly vacuum - r = 0.001 # beginning weights for non-vacuum params + r = torch.tensor(0.0001) # beginning weights for non-vacuum params # The following solves for non-vacuum logits such that weight=r - l = torch.log(r) - torch.log(len(atomic_species)) - torch.log1m(r) - w[:, :-1, ...] = l + l = torch.log(r / ((1 - r) * len(atomic_species))) + w[:, :-1, ...] = l.to(w.device) self.species_weight_logits = nn.Parameter(w) self.atomic_positions = nn.Parameter( torch.zeros((1, 3, *self.coarse_voxels), dtype=torch.float32) diff --git a/ptychopath/potential/kirkland.py b/ptychopath/potential/kirkland.py index 221beb148983607cd6f1ff68cd62195de340b009..d9bb0392ee843652df642eb5e772e42ca3d97785 100644 --- a/ptychopath/potential/kirkland.py +++ b/ptychopath/potential/kirkland.py @@ -68,6 +68,13 @@ def from_kirkland_table(t): atom_parameters = { + "C": from_kirkland_table( + [ + [2.12080767e-001, 2.08605417e-001, 1.99811865e-001, 2.08610186e-001], + [1.68254385e-001, 5.57870773e000, 1.42048360e-001, 1.33311887e000], + [3.63830672e-001, 3.80800263e000, 8.35012044e-004, 4.03982620e-002], + ] + ), # Carbon, Z=6, chisq=0.102440 "O": from_kirkland_table( [ [3.39969204e-001, 3.81570280e-001, 3.07570172e-001, 3.81571436e-001], @@ -89,6 +96,13 @@ atom_parameters = { [1.89002347e-001, 3.98427790e-001, 2.29619589e-001, 9.01419843e-001], ] ), # Copper, Z=29, chisq=0.010467 + "Cr": from_kirkland_table( + [ + [1.34348379e000, 1.25814353e000, 5.07040328e-001, 1.15042811e001], + [4.26358955e-001, 8.53660389e-002, 1.17241826e-002, 6.00177061e-002], + [5.11966516e-001, 1.53772451e000, 3.38285828e-001, 6.62418319e-001], + ] + ), # Chromium, Z=24, chisq=0.015287 "Sr": from_kirkland_table( [ [1.37373086e-002, 1.87469061e-002, 1.97548672e000, 6.36079230e000], diff --git a/ptychopath/regrid.py b/ptychopath/regrid.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c239c80fbd24e5e5b5d01e7c76dd8323494b23 --- /dev/null +++ b/ptychopath/regrid.py @@ -0,0 +1,234 @@ +""" +Interpolation helper methods. +""" +import torch +from torch.nn import functional as F + +from typing import Tuple, Union + + +def grid_sample_complex(image, grid, align_corners=False, **kwargs): + """ + Resample an image along a grid, gracefully handling complex dtypes. + """ + # types of grid and image must match + float_type = image.real.dtype if image.is_complex() else image.dtype + + grid = grid.type(float_type).to(image.device) + + if image.is_complex(): + # sample each component separately then recombine + out_r = F.grid_sample(image.real, grid, align_corners=align_corners, **kwargs) + out_i = F.grid_sample(image.imag, grid, align_corners=align_corners, **kwargs) + return torch.complex(out_r, out_i) + else: + return F.grid_sample(image, grid, align_corners=align_corners, **kwargs) + + +def affine2d_grid( + input_shape: torch.Tensor, M: torch.Tensor, output_shape: Tuple[int] = None +): + r""" + Create a resampling grid for an affine transform provided as a 2x3 matrix + + Args: + image: torch.Tensor: + The image to be resampled of shape NCHW. Can be complex. + M: torch.Tensor: + Matrix of size 2x3 mapping output pixels to input pixel positions. + This matrix uses HW ordering of axes; note this matches `image` + dimension ordering but does not match the convention used by + :func:`F.affine_grid`. + output_shape: Tuple[int]: + The requested size of the output image. Defaults to same size as + input image. + """ + # drop the last row if augmented coord matrix is provided + M = M[..., :2, :] + + # reorder axes from YX to XY to prepare for affine_grid + B = torch.tensor( + [ + [0, 1, 0], + [1, 0, 0], + [0, 0, 1], + ], + device=M.device, + dtype=M.dtype, + ) + M = (B[:2, :2] @ M @ B).view(1, 2, 3).repeat((input_shape[0], 1, 1)) + + return F.affine_grid( + M, + size=(*input_shape[:2], *output_shape[-2:]), + align_corners=False, + ) + + +def similarity2d_grid( + input_shape, + input_spacing, + output_spacing, + rotation, + input_center=(0.0, 0.0), + output_center=(0.0, 0.0), + output_shape=None, +): + r""" + Create resampling grid by scaling, rotating, and translating. + + Scaling and rotating is done about the center of the input image, while + translations are determined automatically such that the center of the input + image is placed exactly at the `center_output_pix` location. + + Args: + input_shape: Tuple[int]: + Shape of the input image. + input_spacing: Tuple[float]: + Size of input pixels in units of length. + output_spacing: Tuple[float]: + Size of output pixels in units of length matching those of + `input_spacing`. + rotation: float: + Amount to rotate in radians. This rotates output pixels + counterclockwise for positive arguments, leading to an apparent + _clockwise_ rotation of the input image. + input_center: Tuple[float]: + Position of the center of the input image, in + units of pixels, relative to the fractional pixel (N - 1) / 2 in + each dimension. Defaults to (0.0, 0.0). + output_center: Tuple[float]: + Position of the center of the input image in the output image, in + units of pixels, relative to a (0, 0) point at the center of the + output image, i.e. the fractional pixel (N - 1) / 2 in each + dimension. Defaults to (0.0, 0.0). + output_shape: Tuple[int]: + Shape of the output image. If omitted, defaults to shape of input + image. + """ + # Starting with an output pixel, we compose the following steps in this + # order: + # 1. transform from (-1, 1) to output pixels + # 2. subtract center output pixel location + # 3. scale from pixels to world coords + # 4. rotate + # 5. convert from world coords to input pixels. + # 6. add center input pixel location + # 7. transform from input pixels to (-1, 1) + + ## Conversions between pixels and (-1, 1) range + # Note that mesh_grid and affine_grid expect a parameter we'll call q in the + # range (-1, 1) which is equivalent to pixel coords that lie in [0, N-1]. + # That is: + # q = -1 + 2 * (p ) / (N - 1) # align_corners=True (default) + # = pz * 2 / (N - 1) # where pz = p - (N - 1) / 2 + # q = -1 + 2 * (p + 0.5) / N # align_corners=False + # = pz * 2 / N + # The inverse transformation is: + # pz = q * (N - 1) / 2 # align_corners=True (default) + # pz = q * N / 2 # align_corners=False + # We will work in q and pz coords, and not explicitly handle the midpoint + # of the image which is always (0, 0) or (N - 1) / 2 in these + # parametrizations, respectively + Pin = torch.tensor( # map from q to pz + [ + [input_shape[-2] * 0.5, 0, 0], + [0, input_shape[-1] * 0.5, 0], + [0, 0, 1], + ], + dtype=torch.float64, + ) + Pout = torch.tensor( + [ + [output_shape[-2] * 0.5, 0, 0], + [0, output_shape[-1] * 0.5, 0], + [0, 0, 1], + ], + dtype=torch.float64, + ) + + ## Shifting the center of each image, in pixels + # These center the point at pz=(0, 0) instead of given center + Tin = torch.tensor( + [ + [1, 0, -input_center[0]], + [0, 1, -input_center[1]], + [0, 0, 1], + ], + dtype=torch.float64, + ) + Tout = torch.tensor( + [ + [1, 0, -output_center[0]], + [0, 1, -output_center[1]], + [0, 0, 1], + ], + dtype=torch.float64, + ) + + # scale comes directly from pixel spacing differences + Sin = torch.diag( + torch.tensor([input_spacing[0], input_spacing[1], 1.0], dtype=torch.float64) + ) + Sout = torch.diag( + torch.tensor([output_spacing[0], output_spacing[1], 1.0], dtype=torch.float64) + ) + + # Rotation + r = torch.as_tensor(rotation, dtype=torch.float64) + c = torch.cos(r) + s = torch.sin(r) + R = torch.tensor( + [ + [c, -s, 0.0], + [s, c, 0.0], + [0.0, 0.0, 1.0], + ] + ) + + # Now we compose all steps together in the correct order + # S @ T @ P converts each q point to pixel coords, translates it to adjust + # for the center given in pixels, then scales to world coordinates using the + # pixel scaling. We do this to the output pixels, rotate them, then invert + # the procedure to map to q coordinates in input space. + STPin_inv = torch.linalg.inv(Sin @ Tin @ Pin) + M = STPin_inv @ R @ (Sout @ Tout @ Pout) + + return affine2d_grid(input_shape, M, output_shape) + + +def similarity2d( + image, + input_spacing, + output_spacing, + rotation, + input_center=(0.0, 0.0), + output_center=(0.0, 0.0), + output_shape=None, + **kwargs, +): + input_dim = image.ndim + assert image.ndim <= 4 + while image.ndim < 4: + image = image.unsqueeze(0) + + if output_shape is None: + output_shape = image.shape[-2:] + + g = similarity2d_grid( + image.shape, + input_spacing, + output_spacing, + rotation, + input_center=input_center, + output_center=output_center, + output_shape=output_shape, + ) + + interped = grid_sample_complex(image, g, align_corners=False, **kwargs) + + # adjust output to match input so we can ignore batch and channel + while interped.ndim > input_dim: + interped = interped.squeeze(0) + + return interped diff --git a/tests/test_aperture.py b/tests/test_aperture.py new file mode 100644 index 0000000000000000000000000000000000000000..dee428c8ad4e8d190105d526518b30390bbeed4d --- /dev/null +++ b/tests/test_aperture.py @@ -0,0 +1,21 @@ +""" +Test aperture estimation and calibration. +""" +import pyms +import pytest +import torch + +from .fixtures.random import set_seeds + +from ptychopath.microscope import calibration + +set_seeds() + + +def test_calibrate_aperture(): + """Test whether the code runs for aperture estimation""" + cbed_av = torch.randn(15, 12) + sp = (0.2, 0.5) + ap = calibration.aperture.calibrate_aperture( + cbed_av, pixel_spacing=sp, threshold=0.3 + ) diff --git a/tests/test_dpc.py b/tests/test_dpc.py new file mode 100644 index 0000000000000000000000000000000000000000..fd78b232ade2b59ebc98e0549b80c9d7e1fbfe79 --- /dev/null +++ b/tests/test_dpc.py @@ -0,0 +1,22 @@ +"""Test DPC-based calibration methods""" +import pyms +import pytest +import torch + +from .fixtures.random import set_seeds + +from ptychopath.microscope.calibration import dpc + +set_seeds() + + +def test_calibrate_rotation(): + """Test whether the code runs for rotation calibration""" + com = torch.randn(2, 15, 12) + dpc.calibrate_rotation(com) + + +def test_com2phase(): + """Test whether the code runs for Wiener deconvolution phase estimation""" + com = torch.randn(2, 15, 12) + phase = dpc.com2phase(com, spacing=(0.15, 0.13)) diff --git a/tests/test_microscope.py b/tests/test_microscope.py index e85d43efc493d60752a7b38adca00cf8a045e8bc..8b16b1dd5f46815e97ffc035954114fc1b80a833 100644 --- a/tests/test_microscope.py +++ b/tests/test_microscope.py @@ -55,6 +55,7 @@ def test_ccd(): "pixels": [512, 513], "pixel_spacing": [0.2, 0.35], "beam_axis": [256.3, 255.47], + "rotation": 0.1, } ccd = CCD.from_dict(d) ccdd = ccd.to_dict() @@ -66,6 +67,7 @@ def test_microscope(): "pixels": [512, 513], "pixel_spacing": [0.2, 0.35], "beam_axis": [256.3, 255.47], + "rotation": 0.1, } dab = { # use different values to detect typos "C1": 0.11, # Defocus (real) @@ -111,6 +113,7 @@ def test_resample_cbed(): "pixels": [32, 31], "pixel_spacing": [0.1, 0.1], "beam_axis": [256.3, 255.47], + "rotation": 0.1, }, "aberrations": { # use different values to detect typos "C1": 0.11, # Defocus (real) @@ -134,7 +137,7 @@ def test_resample_cbed(): ) # make up a complex exit wavefunction (FFT) - exitwave_f = torch.randn((1, 16, 17), dtype=torch.complex64) + exitwave_f = torch.randn((1, 1, 16, 17), dtype=torch.complex64) grid_spacing = (100.0, 0.1, 0.1) # grid in angstrom exitwave_f_intensity = exitwave_f.abs() ** 2 @@ -179,6 +182,7 @@ def test_make_probe(): "pixels": [32, 31], "pixel_spacing": [0.1, 0.1], "beam_axis": [256.3, 255.47], + "rotation": 0.1, }, "aberrations": { # use different values to detect typos "C1": 0.11, # Defocus (real) diff --git a/tests/test_multislice.py b/tests/test_multislice.py index 451df9281482600194e0ba355eea94b307acfd11..efd7fee75b40adb82788af1e43ac6ed58ea0bec0 100644 --- a/tests/test_multislice.py +++ b/tests/test_multislice.py @@ -51,6 +51,29 @@ def test_module(device): assert psi_f.dtype == probe_f.dtype +@pytest.mark.parametrize("device", devices) +def test_module_nobwmask(device): + # make up a probe, propagator, and potential + probe_f = torch.randn(3, 5, dtype=torch.complex64, device=device) + # five random potential slices + V = torch.randn(4, *probe_f.shape, dtype=torch.float32, device=device) + + grid_spacing = (0.05, 0.05, 0.05) + + m = multislice.Multislice( + plane_shape=probe_f.shape, + grid_spacing=grid_spacing, + energy=300.0, + bw_limit=None, + ) + m = m.to(device) + psi_f = m(probe_f, V) + # also test with shift + psi_f = m(probe_f, V, probe_shift=(1.4, -2.3)) + + assert psi_f.dtype == probe_f.dtype + + @pytest.mark.parametrize( "device", [ diff --git a/tests/test_regrid.py b/tests/test_regrid.py new file mode 100644 index 0000000000000000000000000000000000000000..220c7123eaab95b721920e07ca999500eb6593e2 --- /dev/null +++ b/tests/test_regrid.py @@ -0,0 +1,63 @@ +import numpy as np +import pyms +import pytest +import torch + +from .fixtures.random import set_seeds + +from ptychopath import regrid + +set_seeds() + + +def test_regrid(): + im = torch.randn(2, 2, 5, 4) + insp = 0.2, 0.2 + outsp = 0.3, 0.1 + outsh = [3, 6] + + out = regrid.similarity2d(im, insp, outsp, rotation=np.pi / 2, output_shape=outsh) + assert list(out.shape[-2:]) == outsh + + +def test_similarity_identity(): + """Test that an identity transform on different grids works""" + insh = [5, 3] + im = torch.rand(*insh) + insp = 0.2, 0.2 + outsp = 0.2, 0.2 + outsh = [3, 3] + + out = regrid.similarity2d( + im.view(1, 1, *insh), insp, outsp, rotation=0.0, output_shape=outsh + ).squeeze() + + assert torch.allclose(out, im[1:-1, :]) + + +def test_similarity_rot90(): + """Test that rotating by 90 degrees works as expected""" + insh = [5, 3] + im = torch.rand(*insh) + insp = 0.2, 0.2 + outsp = 0.2, 0.2 + outsh = [3, 5] + + out = regrid.similarity2d( + im.view(1, 1, *insh), insp, outsp, rotation=-np.pi / 2, output_shape=outsh + ).squeeze() + + assert torch.allclose(out, torch.rot90(im)) + + +def test_similarity_complex(): + """Test that regridding complex images works""" + im = torch.randn(2, 2, 5, 4, 2) + im = torch.view_as_complex(im) + insp = 0.2, 0.2 + outsp = 0.3, 0.1 + outsh = [3, 6] + + out = regrid.similarity2d(im, insp, outsp, rotation=np.pi / 2, output_shape=outsh) + assert out.dtype == im.dtype + assert list(out.shape[-2:]) == outsh