Commit bffaffa2 authored by Laanait, Nouamane
Browse files

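Add a boolean `transmit` option to the `multislice` method; it is forwarded to `__propagate_beams`, where it controls whether the beams are propagated through the atomic potential slices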

parent dcdfd18b
Loading
Loading
Loading
Loading
+22 −22
Original line number Diff line number Diff line
@@ -443,7 +443,6 @@ class MSAGPU(MSAHybrid):
            from pycuda.tools import clear_context_caches
            clear_context_caches()
    
    
    def build_potential_slices(self, slice_thickness):
        # find number of slices and atomic sites per slice
        self.slice_t = slice_thickness
@@ -643,7 +642,7 @@ class MSAGPU(MSAHybrid):
            grid_1d = (int((shape_x * shape_y) / block_1d[0]), int(shape_z), 1)
            return block_1d, grid_1d

    def multislice(self, bandwidth=1/3, unified_mem=False, batch_size=256):
    def multislice(self, bandwidth=1/3, unified_mem=False, batch_size=256, transmit=True):
        """

        :param bandwidth:
@@ -762,7 +761,7 @@ class MSAGPU(MSAHybrid):
                num_probes = np.int32(self.probe_positions[phase][batch].shape[0])
                self.print_debug('batch: %s' % format(batch))
                self.__propagate_beams(num_probes, batch, probe_d, propag_d, psi_k_d, norm_const, grid_steps_d, grid_range_d,
                                     self.probes[phase][batch], plan, ones_d, stream)
                                     self.probes[phase][batch], plan, ones_d, stream, transmit=transmit)
            # ctx.synchronize()
        # 4. clean-up
            for plan, probe_d, norm_const in zip(plans, probes_d, norm_consts):
@@ -793,7 +792,7 @@ class MSAGPU(MSAHybrid):

    def __propagate_beams(self, num_probes, batch, psi_pos_d, propag_d, psi_k_d,
                        norm_const_d, grid_steps_d, grid_range_d,
                        psi_x_pos_pin, fft_plan_probe, ones_d, stream):
                        psi_x_pos_pin, fft_plan_probe, ones_d, stream, transmit=True):
        """
        :param batch:
        :param psi_pos_d:
@@ -835,6 +834,7 @@ class MSAGPU(MSAHybrid):

        # 2. Propagate probes through atomic potential
        # ctx.synchronize()
        if transmit:
            for i in range(self.num_slices):
                # self.print_debug('Atomic potential slice #%d' % i)
                multwise_stack_func.prepared_async_call(grid_3d, block_3d, stream, psi_pos_d, self.pot_dev_ptr, num_probes, np.int32(i),