Commit a9ff105d authored by Laanait, Nouamane's avatar Laanait, Nouamane
Browse files

fixing batch issue and grid generation cuda kernels

parent c00e2b32
Loading
Loading
Loading
Loading
+49 −23
Original line number Diff line number Diff line
@@ -223,7 +223,7 @@ class MSA:
        psi_k = aperture * phase_error
        psi_k = psi_k.astype(np.complex64)
        x, y = probe_position
        kr = k_x * x + k_y * y
        kr = -k_x * x - k_y * y
        phase_shift = np.exp(2 * np.pi * 1.j * kr).astype(np.complex64)
        if pyfftw is not None:
            psi_x = pyfftw.interfaces.numpy_fft.ifft2(psi_k * phase_shift)
@@ -250,24 +250,28 @@ class MSA:
        bl_mask = np.heaviside(max(arr_shape[0], arr_shape[1]) * radius - r_grid, 0)
        return bl_mask.astype(np.float32)

    def generate_probe_positions(self, probe_step=np.array([0.1, 0.1]), probe_range=np.array([[0., 1.0], [0., 1.0]]), grid_steps=None):
    def generate_probe_positions(self, probe_step=np.array([0.1, 0.1]), probe_range=np.array([[0., 1.0], [0., 1.0]]), 
            grid_steps=None, fraction=0.5, origin = np.array([0,0])):
        if grid_steps is not None:
            grid_range_start = (0.5 - self.dims[:2]/2)/2
            grid_range_stop = (0.5 + self.dims[:2]/2)/2
            grid_range_start = (0.5 + origin[::-1] * self.dims[:2]  - self.dims[:2] * fraction)/2
            grid_range_stop = (0.5 + origin[::-1] * self.dims[:2] + self.dims[:2]  * fraction)/2
            # grid_range_start = (0.5 - self.dims[:2]/2)/2
            # grid_range_stop = (0.5 + self.dims[:2]/2)/2
            x_pos, y_pos = np.mgrid[grid_range_start[0]:grid_range_stop[0]:-1j*grid_steps[0], 
                                    grid_range_start[1]:grid_range_stop[1]:-1j*grid_steps[1]]
            grid_steps_x, grid_steps_y = grid_steps.astype(np.int) 
            grid_range_x = np.array([grid_range_start[0], grid_range_stop[0]])
            grid_range_y = np.array([grid_range_start[1], grid_range_stop[1]])
            probe_pos = np.array([[-x, y] for y, x in zip(y_pos.flatten()[::-1], x_pos.flatten())])
            probe_pos = np.array([[y, x] for y, x in zip(y_pos.flatten(), x_pos.flatten())])
            # probe_pos = np.array([[x, y] for y, x in zip(y_pos.flatten(), x_pos.flatten())])

        else:                            
            grid_steps_x, grid_steps_y = np.floor(np.diff(probe_range).flatten() * self.dims[:2] / probe_step).astype(np.int)
            grid_range_x, grid_range_y = [(probe_range[i] - np.ones((2,)) * 0.5) * self.dims[i]
                                        for i in range(2)]
            y_pos, x_pos = np.mgrid[grid_range_y[0]: grid_range_y[1]: -1j * grid_steps_y,
                        grid_range_x[0]: grid_range_x[1]: -1j * grid_steps_x]
            probe_pos = np.array([[y, -x] for y, x in zip(y_pos.flatten()[::-1], x_pos.flatten())])
            x_pos, y_pos = np.mgrid[grid_range_x[0]: grid_range_x[1]: -1j * grid_steps_x,
                        grid_range_y[0]: grid_range_y[1]: -1j * grid_steps_y]
            probe_pos = np.array([[y, x] for y, x in zip(y_pos.flatten(), x_pos.flatten())])
        self.grid_steps = np.array([grid_steps_x, grid_steps_y])
        self.grid_range = np.array([grid_range_x, grid_range_y]).flatten()
        self.probe_positions = probe_pos
@@ -310,8 +314,8 @@ class MSA:
            trans_probes = self.propagate_beam([None, probe_pos, save_probes, probe_grid, bandwidth])

        self.print_verbose('Propagated %d probe wavefunctions' % trans_probes.shape[0])
        self.trans_probes = trans_probes
        return self.trans_probes
        self.probes = trans_probes
        return trans_probes

    def propagate_beam(self, args):
        probe_num, probe_pos, save_probes, probe_grid, bandwidth = args
@@ -357,7 +361,7 @@ class MSA:
        return probe_last

    def check_simulation(self):
        prob = np.sum([np.abs(probe)**2 for probe in self.trans_probes], axis=(1, 2))
        prob = np.sum([np.abs(probe)**2 for probe in self.probes], axis=(1, 2))
        max_val = prob.max()
        min_val = prob.min()
        print('Max (Min) Integrated Intensity: %2.2f (%2.2f)' % (max_val, min_val))
@@ -419,6 +423,21 @@ class MSA:
        if output:
            return - pot_slices
    
    def integrate_cbed(self, detector_array=None, detector_params={'inner_angle':50e-3, 'outer_angle':100e-3}):
        '''
        Integrate every simulated probe pattern over a detector mask.

        :param detector_array: optional mask multiplied element-wise with each probe pattern. It must broadcast
            against the last two axes of ``self.probes`` after the reshape below. If None, an annular mask is
            built from ``detector_params``.
        :param detector_params: dict with keys 'inner_angle' and 'outer_angle'. It is only read (never modified)
            and only used when ``detector_array`` is None. Units are presumably radians -- TODO confirm.
        :return: tuple ``(intgr_cbed, detector_array)``: the sum of ``self.probes * detector_array`` over the last
            two axes (one value per probe-grid position), and the mask that was used.
        :raises AssertionError: if inner_angle >= outer_angle or outer_angle >= ``self.max_ang``.

        Side effect: ``self.probes`` is reshaped in place to ``grid_steps + sampling``.
        '''
        if detector_array is None:
            inner_angle = detector_params['inner_angle']
            outer_angle = detector_params['outer_angle']
            # The assert message must be a string: the previous ``print(...)`` evaluated to None, so the raised
            # AssertionError carried no explanation.
            assert inner_angle < outer_angle and outer_angle < self.max_ang, \
                'Detector angles exceed maximum scattering angle simulated and/or values are not consistent'
            # Convert the angles into the radius argument expected by bandwidth_limit_mask.
            angular_extent = self.kpix_size * self.Lambda * self.sampling
            inner_radius = inner_angle / angular_extent / 2
            outer_radius = outer_angle / angular_extent / 2
            print(inner_radius, outer_radius)
            # Use the builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
            inner_segment = self.bandwidth_limit_mask(self.sampling, radius=inner_radius[0]).astype(bool)
            outer_segment = self.bandwidth_limit_mask(self.sampling, radius=outer_radius[0]).astype(bool)
            # Annulus = pixels that belong to exactly one of the two masks.
            detector_array = np.logical_xor(outer_segment, inner_segment)
        new_shape = list(self.grid_steps) + list(self.sampling)
        self.probes = self.probes.reshape(new_shape)
        # NOTE(review): this sums the stored probe values as they are. If self.probes holds complex amplitudes,
        # an intensity (|psi|^2) may be what is wanted here -- confirm against callers.
        intgr_cbed = np.sum(self.probes * detector_array, axis=(2,3))
        return intgr_cbed, detector_array

class MSAHybrid(MSA):
    '''
    Class that performs potential building on CPU and beam propagation on GPU using scikit-cuda cufft interface. 
@@ -546,19 +565,20 @@ class MSAGPU(MSAHybrid):
            from pycuda.tools import clear_context_caches
            clear_context_caches()

    def integrate_potential_slices(self, pot, grid, slice_thickness):
    def integrate_potential_slices(self, pot, grid, slice_thickness, output=False):
        pot_slices = super(MSAGPU, self).integrate_potential_slices(pot, grid, slice_thickness, output=True)
        pot_slices = np.exp(1.j * self.sigma * pot_slices).astype(np.complex64)
        self.potential_slices = cuda.register_host_memory(pot_slices)
        pot_slices_phase = np.exp(1.j * self.sigma * pot_slices).astype(np.complex64)
        self.potential_slices = cuda.register_host_memory(pot_slices_phase)
        potential_slices_d = cuda.to_device(self.potential_slices)

        # store device allocation ref for later use
        self.pot_dev_ptr = potential_slices_d
        self.vars = []
        self.vars.append(self.potential_slices.base)
        if output:
            return pot_slices

    def build_potential_slices(self, ctx, slice_thickness):
        self.ctx = ctx
    def build_potential_slices(self, slice_thickness):
        # self.ctx = ctx
        # find number of slices and atomic sites per slice
        self.slice_t = slice_thickness
        self.num_slices = np.int32(np.floor(self.dims[-1] / slice_thickness))
@@ -779,7 +799,6 @@ class MSAGPU(MSAHybrid):
        self.print_debug('block, grid:', block_3d, grid_3d)

        # allocate memory
        # self.probes = np.empty((self.num_probes, shape_y, shape_x), dtype=np.complex64)
        self.propag = cuda.aligned_zeros((int(self.sampling[0]), int(self.sampling[1])), np.complex64)
        self.vars.append(self.propag.base)
        self.propag = cuda.register_host_memory(self.propag)
@@ -805,7 +824,6 @@ class MSAGPU(MSAHybrid):
        ones_d = cuda.mem_alloc(ones.nbytes)
        cuda.memcpy_htod_async(ones_d, ones, cuda.Stream())
        

        # grab needed kernels
        propag_func = self.kernels['propagator']
        mask_func = self.kernels['hard_aperture']
@@ -848,7 +866,7 @@ class MSAGPU(MSAHybrid):
        # 2. Generate batch fft plans and create pinned memory pointers and cuda streams
            for batch_num in range(num_batches+1):
                if batch_num == num_batches:
                    slice_obj = slice(batch_num * batch_size, None)
                    slice_obj = slice(batch_num * batch_size, phase.stop - phase.start)
                else:
                    slice_obj = slice(batch_num * batch_size, (batch_num + 1) * batch_size)
                batches.append(slice_obj)
@@ -861,10 +879,15 @@ class MSAGPU(MSAHybrid):
                probes_d.append(cuda.mem_alloc(int(num_probes*np.prod(self.sampling)*8)))
                plans.append(skfft.Plan(self.sampling, np.complex64, np.complex64, batch=num_probes, stream=stream))
                norm_consts.append(cuda.mem_alloc(np.empty(num_probes,dtype=np.float32).nbytes))
            self.print_debug('Batches: %s' % format(batches))
        # 3. Propagate Beams
            self.print_verbose("Simulating probes %d out of %d..." % (phase.stop, self.num_probes))
            for batch, stream, probe_d, plan,  norm_const in zip(batches, streams, probes_d, plans, norm_consts):
                num_probes = np.int32(self.probe_positions[phase][batch].shape[0])
                self.print_debug('batch: %s' % format(batch))
                grid_range = self.probe_positions[phase][batch].astype(np.float32) 
                grid_range_d = cuda.mem_alloc(grid_range.nbytes)
                cuda.memcpy_htod_async(grid_range_d, grid_range, stream)
                self.print_debug('batch: %s, stream: %s' % (format(batch), format(stream)))
                self.__propagate_beams(num_probes, batch, probe_d, propag_d, psi_k_d, norm_const, grid_steps_d, grid_range_d,
                                     self.probes[phase][batch], plan, ones_d, stream, transmit=transmit)
            self.ctx.synchronize()
@@ -878,7 +901,10 @@ class MSAGPU(MSAHybrid):
            self.ctx.synchronize()
            self.print_verbose('finished simulation phase #%d' % i)
            # self.probes /= self.normalization
            self.probes[phase][batch] = self.probes[phase][batch]/self.normalization
            if transmit:
                self.probes[phase] /= self.normalization
            else:
                self.probes[phase] /= np.sqrt(self.normalization)
            sim_t = time()-t
            self.print_verbose('Propagated %d probes in %2.4f s' % (self.probe_positions[phase].shape[0], sim_t))
        
+1 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@
     __global__ void BuildScatteringPotential(pycuda::complex<float> slice[][{{y_sampling}}][{{x_sampling}}],
                                  float atom_pot_stack[][{{pot_shape_y}}][{{pot_shape_x}}],
                                  float sites[][{{sites_size}}],
                                  // sites[] !!!
                                  float sigma)

     {
+53 −15
Original line number Diff line number Diff line
@@ -15,26 +15,47 @@ __inline__ __device__ int warpReduceSumSync(int val, int mask){
    return k_rad;
}

 __inline__ __device__ float phase_shift(float k_max, int size_x, int size_y, int col_idx, int row_idx, int stk_idx,
    int *grid_step, float *grid_range){
//  __inline__ __device__ float phase_shift(float k_max, int size_x, int size_y, int col_idx, int row_idx, int stk_idx,
//     int *grid_step, float *grid_range){
//     const double pi = acos(-1.0);
//     float kx = float(col_idx) * k_max/float(size_x - 1) - k_max/2.;
//     float ky = float(row_idx) * k_max/float(size_y - 1) - k_max/2.;
//     int grid_step_x = grid_step[0];
//     int grid_step_y = grid_step[1];
//     float grid_start_x = grid_range[0];
//     float grid_end_x = grid_range[1];
//     float grid_start_y = grid_range[2];
//     float grid_end_y = grid_range[3];
//     float ry_idx = rintf(floorf(stk_idx * 1.f  / grid_step_y));
//     float rx_idx = stk_idx - ry_idx * grid_step_x;
//     float ry = ry_idx * (grid_end_y - grid_start_y) / (grid_step_y - 1) + grid_start_y;
//     float rx = rx_idx * (grid_end_x - grid_start_x) / (grid_step_x - 1) + grid_start_x;
//     float kr = - kx * rx - ky * ry;
//     return kr;
// }

 __inline__ __device__ float phase_shift(float k_max, int size_x, int size_y, int col_idx, int row_idx, 
    int *grid_step, float grid_range[2]){
    const double pi = acos(-1.0);
    float kx = float(col_idx) * k_max/float(size_x - 1) - k_max/2.;
    float ky = float(row_idx) * k_max/float(size_y - 1) - k_max/2.;
    int grid_step_x = grid_step[0];
    int grid_step_y = grid_step[1];
    float grid_start_x = grid_range[0];
    float grid_end_x = grid_range[1];
    float grid_start_y = grid_range[2];
    float grid_end_y = grid_range[3];
    float ry_idx = rintf(floorf(stk_idx * 1.f  / grid_step_y));
    float rx_idx = stk_idx - ry_idx * grid_step_x;
    float ry = ry_idx * (grid_end_y - grid_start_y) / (grid_step_y - 1) + grid_start_y;
    float rx = rx_idx * (grid_end_x - grid_start_x) / (grid_step_x - 1) + grid_start_x;
    // int grid_step_x = grid_step[0];
    // int grid_step_y = grid_step[1];
    // float grid_start_x = grid_range[0];
    // float grid_end_x = grid_range[1];
    // float grid_start_y = grid_range[2];
    // float grid_end_y = grid_range[3];
    // float ry_idx = rintf(floorf(stk_idx * 1.f  / grid_step_y));
    // float rx_idx = stk_idx - ry_idx * grid_step_x;
    // float ry = ry_idx * (grid_end_y - grid_start_y) / (grid_step_y - 1) + grid_start_y;
    // float rx = rx_idx * (grid_end_x - grid_start_x) / (grid_step_x - 1) + grid_start_x;
    float rx = grid_range[1];
    float ry = grid_range[0];
    // printf(rx, ry);
    float kr = - kx * rx - ky * ry;
    return kr;
}


//TODO: 2d vectorize the indexing [idx]
__global__ void norm_const_stack(pycuda::complex<float> arr[][{{x_sampling}} * {{y_sampling}}], float *norm, int size_z) {
  float sum = 0.f;
@@ -179,6 +200,7 @@ __global__ void soft_aperture(float *arr, float k_max, float k_semi, int size_x,
    float k_rad = calc_krad(k_max, size_x, size_y, col_idx, row_idx);
    if (row_idx < size_y && col_idx < size_x)
    {
        // Note that the smoothing factor is fixed at 80 so comparison with CPU results need to match it
        arr[idx] = 1.f / (1.f + expf(- 2.f * 80.f * (k_semi - k_rad)));
    }
}
@@ -199,16 +221,32 @@ __global__ void spherical_phase_error(pycuda::complex<float> *arr, float k_max,
    }
}

// __global__ void build_probes_stack(pycuda::complex<float> psi_pos[][{{y_sampling}}][{{x_sampling}}],
//                     pycuda::complex<float> psi_k[{{y_sampling}}][{{x_sampling}}],
//                     int z_size, float k_max, int *grid_step, float *grid_range){
//     const double pi = acos(-1.0);
//     unsigned col_idx = blockIdx.x*blockDim.x + threadIdx.x;
//     unsigned row_idx = blockIdx.y*blockDim.y + threadIdx.y;
//     unsigned stk_idx = blockIdx.z*blockDim.z + threadIdx.z;
//     if (col_idx < {{x_sampling}} && row_idx < {{y_sampling}} && stk_idx < z_size)
//     {
//         float kr = phase_shift(k_max, {{x_sampling}}, {{y_sampling}}, col_idx, row_idx, stk_idx, grid_step, grid_range);
//         psi_pos[stk_idx][row_idx][col_idx]  = pycuda::complex<float>(cosf(2 * pi * kr), sinf(2 * pi * kr));
//         psi_pos[stk_idx][row_idx][col_idx] *= psi_k[row_idx][col_idx];
//     }
// }

__global__ void build_probes_stack(pycuda::complex<float> psi_pos[][{{y_sampling}}][{{x_sampling}}],
                    pycuda::complex<float> psi_k[{{y_sampling}}][{{x_sampling}}],
                    int z_size, float k_max, int *grid_step, float *grid_range){
                    int z_size, float k_max, int *grid_step, float grid_positions[][2]){
    const double pi = acos(-1.0);
    unsigned col_idx = blockIdx.x*blockDim.x + threadIdx.x;
    unsigned row_idx = blockIdx.y*blockDim.y + threadIdx.y;
    unsigned stk_idx = blockIdx.z*blockDim.z + threadIdx.z;
    if (col_idx < {{x_sampling}} && row_idx < {{y_sampling}} && stk_idx < z_size)
    {
        float kr = phase_shift(k_max, {{x_sampling}}, {{y_sampling}}, col_idx, row_idx, stk_idx, grid_step, grid_range);
        // float kr = phase_shift(k_max, {{x_sampling}}, {{y_sampling}}, col_idx, row_idx, stk_idx, grid_step, grid_range);
        float kr = phase_shift(k_max, {{x_sampling}}, {{y_sampling}}, col_idx, row_idx, grid_step, grid_positions[stk_idx]); 
        psi_pos[stk_idx][row_idx][col_idx]  = pycuda::complex<float>(cosf(2 * pi * kr), sinf(2 * pi * kr));
        psi_pos[stk_idx][row_idx][col_idx] *= psi_k[row_idx][col_idx];
    }