Commit a62412a9 authored by Laanait, Nouamane's avatar Laanait, Nouamane
Browse files

Adding first-principles scattering potential integration and some docs

parent 08c03d91
Loading
Loading
Loading
Loading
+140 −27
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ from .utils import *
from .cuda_kernels import ProbeKernels, PotentialKernels
import numpy as np
from scipy.special import k0
from scipy import integrate
import multiprocessing as mp
import ctypes
import sys
@@ -57,7 +58,11 @@ def setup_device(gpu_id=0):
    atexit.register(_clean_up)
    return ctx 

class MSA(object):
class MSA:
    '''
    Base Class Implementation of the Multi-Slice Algorithm. 
    Potential Construction and Beam Propagation are done in parallel (multiprocessing) and pyfftw (if available) is used for batched 2-D FFTs.
    '''
    def __init__(self, energy, semi_angle, supercell, sampling=np.array([512, 512]), max_angle=None, verbose=False,
                 debug=False, output_dir='', debye_waller=True):
        self.E = energy
@@ -103,12 +108,13 @@ class MSA(object):

        # Set simulation parameters
        self.dims = self.supercell_xyz.max(0) - self.supercell_xyz.min(0)
        if max_angle is None:
        self.max_angle = max_angle
        if self.max_angle is None:
            self.sampling = sampling
            self.kmax = np.min(self.sampling/self.dims[:2])
            self.max_ang = self.kmax * self.Lambda
        else:
            self.max_ang = max_angle
            self.max_ang = self.max_angle
            self.kmax = self.max_ang / self.Lambda
            self.sampling = np.floor(self.kmax * self.dims[:2]).astype(np.int)
            self.sampling = self.sampling[::-1]
@@ -117,7 +123,7 @@ class MSA(object):
        self.sigma = sigma_int(self.E*1e3)
        self.print_verbose('Simulation Parameters:\nSupercell dimensions xyz:%s (Å)\nReal, Reciprocal space pixel sizes:%s Å, %s 1/Å'
              '\nMax angle: %2.2f (rad)\nSampling in real and reciprocal space: %s pixels,\nThermal Effects: %s' %
              (format(np.round(self.dims, 2)), format(np.round(self.pix_size, 2)), format(np.round(self.kpix_size, 2)),
              (format(np.round(self.dims, 3)), format(np.round(self.pix_size, 3)), format(np.round(self.kpix_size, 3)),
               self.max_ang, format(self.sampling), format(self.debye_waller)))

    def print_debug(self, *args, **kwargs):
@@ -216,14 +222,16 @@ class MSA(object):
        # probe wavefunction
        psi_k = aperture * phase_error
        psi_k = psi_k.astype(np.complex64)
        y, x = probe_position
        x, y = probe_position
        kr = k_x * x + k_y * y
        phase_shift = np.exp(2 * np.pi * 1.j * kr).astype(np.complex64)
        if pyfftw is not None:
            psi_x = pyfftw.interfaces.numpy_fft.ifft2(psi_k * phase_shift)
            psi_x = pyfftw.interfaces.numpy_fft.fftshift(psi_x)
        # TODO: need to make fft library choice optional
        # psi_x = np.fft.ifft2(psi_k * phase_shift, norm='ortho')
        # psi_x = np.fft.fftshift(psi_x)
        else:
        # fall back on numpy 
            psi_x = np.fft.ifft2(psi_k * phase_shift, norm='ortho')
            psi_x = np.fft.fftshift(psi_x)
        psi_x /= np.sqrt(np.sum(np.abs(psi_x) ** 2))
        return psi_x.astype(np.complex64), psi_k.astype(np.complex64), aperture.astype(np.float32)

@@ -275,9 +283,12 @@ class MSA(object):
            warn('Probe wave function must be initialized first before calling multi_slice')
            return
        if probe_grid:
            if pyfftw is not None:
                # Define heuristics for python multiprocessing and FFTW multi-threading
                self.fftw_threads = int(self.sampling.max() // 512)
                processes = min(mp.cpu_count(), self.probe_positions.shape[0]) // self.fftw_threads
            else:
                processes = min(mp.cpu_count(), self.probe_positions.shape[0]) 
            chunk = int(np.floor(self.probe_positions.shape[0] / processes))


@@ -318,6 +329,7 @@ class MSA(object):
            slices = self.potential_slices
        slices = np.exp(1.j * self.sigma * slices).astype(np.complex64)

        if pyfftw is not None:
            for (i, trans) in enumerate(slices[::-1]):
                t_psi = pyfftw.byte_align(trans * probe_last)
                fft_fwd = pyfftw.builders.fft2(t_psi, threads=self.fftw_threads, avoid_copy=True)
@@ -326,6 +338,15 @@ class MSA(object):
                probe_last = fft_bwd()
                if save_probes:
                    probes.append(probe_last)
        else:
            for (i, trans) in enumerate(slices[::-1]):
                t_psi = trans * probe_last
                fft_fwd = np.fft.fft2(t_psi)
                temp = fft_fwd * blim_mask * propag
                fft_bwd = np.fft.ifft2(temp)
                probe_last = fft_bwd
                if save_probes:
                    probes.append(probe_last) 
        self.print_debug('finished beam propagation.')

        if save_probes:
@@ -345,8 +366,63 @@ class MSA(object):
                  'Change the sampling and/or slice thickness!')
        return prob

    def integrate_potential_slices(self, pot, grid, slice_thickness, output=False):
        '''
        Integrate a scattering potential, typically coming from all-electron ab initio simulations.
        args:
            pot: 3-d array, in units of e-/Angstrom**3. spatial order is [z,y,x]
            grid: (z,y,x), tuple of arrays with coordinates along each axis in units of Angstrom.
        '''
        
        self.slice_t = slice_thickness
        z_coord, y_coord, x_coord = grid
        z_sampling = z_coord[1] - z_coord[0]
        z_dim = z_coord[-1] - z_coord[0]
        num_slices = int(z_dim // self.slice_t)

        # Cropping to a square array- not necessary and should be removed later after cuda kernel tests 
        max_dim = min(x_coord.size, y_coord.size) 
        x_resamp = x_coord[:max_dim]
        y_resamp = y_coord[:max_dim]
        pot = pot[:, :max_dim, :max_dim]

        # Cropping potential array along z-dir to get int number of slices
        d_slice = int(pot.shape[0]//num_slices)
        if d_slice != int(np.ceil(pot.shape[0]/num_slices)):
            pot_resamp = pot[:d_slice * num_slices, :, :]
            z_resamp = z_coord[:d_slice * num_slices]
        else:
            pot_resamp = pot 
            z_resamp = z_coord

        # Integrating
        pot_slices = np.split(pot_resamp, num_slices, axis=0)
        zcoord_slices = np.split(z_resamp, num_slices, axis=0)
        pot_slices = np.array([integrate.trapz(pot_slice, x=z_coord, axis=0) for pot_slice, z_coord in zip(pot_slices, zcoord_slices)])
        self.potential_slices = - pot_slices

        # Update sim params
        ## TODO: below doesn't take into account user-specified max_ang
        self.num_slices = self.potential_slices.shape[0]
        self.dims = np.array([x_coord.max(), y_coord.max(), z_coord.max()])
        if self.max_angle is None:
            self.sampling = np.array(self.potential_slices.shape[1:]) 
            self.kmax = np.min(self.sampling/self.dims[:2])
            self.max_ang = self.kmax * self.Lambda
        else:
            self.max_ang = self.max_angle
            self.kmax = self.max_ang / self.Lambda
            self.sampling = np.floor(self.kmax * self.dims[:2]).astype(np.int)
            self.sampling = self.sampling[::-1]
        self.pix_size = self.dims[:2][::-1] / self.sampling
        self.kpix_size = self.kmax/self.sampling
        if output:
            return - pot_slices

class MSAHybrid(MSA):
    '''
    Class that performs potential building on CPU and beam propagation on GPU using scikit-cuda cufft interface. 
    '''
    def plan_simulation(self, num_probes=None):
        if num_probes is None:
            num_probes = self.num_probes
@@ -431,8 +507,31 @@ class MSAHybrid(MSA):
        pool.join()
        return probes


class MSAGPU(MSAHybrid):
    '''
    Class that implements MSA on a single GPU. 
    Potential Building + Beam Propagation are done on the GPU using JIT-compiled CUDA-C Kernels.
    '''

    def setup_device(self, gpu_id=0):
        global ctx
        cuda.init()
        dev = cuda.Device(gpu_id)
        ctx = dev.make_context()

        import atexit
        def _clean_up():
            if ctx is not None:
                try:
                    ctx.pop()
                    ctx.detach()
                except Exception as e:
                    warn(format(e))
            from pycuda.tools import clear_context_caches
            clear_context_caches()
        atexit.register(_clean_up)
        self.ctx = ctx
        return ctx 

    @staticmethod
    def clean_up(ctx=None, vars=None):
@@ -447,6 +546,17 @@ class MSAGPU(MSAHybrid):
            from pycuda.tools import clear_context_caches
            clear_context_caches()

    def integrate_potential_slices(self, pot, grid, slice_thickness):
        pot_slices = super(MSAGPU, self).integrate_potential_slices(pot, grid, slice_thickness, output=True)
        pot_slices = np.exp(1.j * self.sigma * pot_slices).astype(np.complex64)
        self.potential_slices = cuda.register_host_memory(pot_slices)
        potential_slices_d = cuda.to_device(self.potential_slices)

        # store device allocation ref for later use
        self.pot_dev_ptr = potential_slices_d
        self.vars = []
        self.vars.append(self.potential_slices.base)

    def build_potential_slices(self, ctx, slice_thickness):
        self.ctx = ctx
        # find number of slices and atomic sites per slice
@@ -767,8 +877,8 @@ class MSAGPU(MSAHybrid):
               del probe_d, norm_const, plan
            self.ctx.synchronize()
            self.print_verbose('finished simulation phase #%d' % i)
            self.probes /= self.normalization
            # self.probes[phase][batch] = self.probes[phase][batch]/self.normalization
            # self.probes /= self.normalization
            self.probes[phase][batch] = self.probes[phase][batch]/self.normalization
            sim_t = time()-t
            self.print_verbose('Propagated %d probes in %2.4f s' % (self.probe_positions[phase].shape[0], sim_t))
        
@@ -855,8 +965,11 @@ class MSAGPU(MSAHybrid):
        # ctx.synchronize()
        cuda.memcpy_dtoh_async(psi_x_pos_pin, psi_pos_d, stream=stream)


class MSAMPI(MSAGPU):
    '''
    Class that extends MSAGPU to multiple GPUs and/or nodes.
    The parallel distribution falls back on hdf5 writing of the results if the memory associated with the MPI root rank is exceeded.
    '''
    def __init__(self, *args, **kwargs):
        #global comm
        comm = MPI.COMM_WORLD
+1 −1
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ setup(
    author='Numan Laanait',
    author_email='laanaitn@ornl.gov',
    description='',
    install_requires=['scipy', 'pymatgen', 'numpy', 'pycuda==2019.1', 'scikit-cuda', 'mpi4py'],
    install_requires=['scipy', 'pymatgen', 'numpy', 'pycuda>=2019.1', 'scikit-cuda', 'mpi4py'],
    #install_requires=['numpy', 'scipy', 'pymatgen', 'pybtex'],
    test_suite='tests',
    python_requires='>=3.6',