Loading namsa/msa.py +22 −22 Original line number Diff line number Diff line Loading @@ -443,7 +443,6 @@ class MSAGPU(MSAHybrid): from pycuda.tools import clear_context_caches clear_context_caches() def build_potential_slices(self, slice_thickness): # find number of slices and atomic sites per slice self.slice_t = slice_thickness Loading Loading @@ -643,7 +642,7 @@ class MSAGPU(MSAHybrid): grid_1d = (int((shape_x * shape_y) / block_1d[0]), int(shape_z), 1) return block_1d, grid_1d def multislice(self, bandwidth=1/3, unified_mem=False, batch_size=256): def multislice(self, bandwidth=1/3, unified_mem=False, batch_size=256, transmit=True): """ :param bandwidth: Loading Loading @@ -762,7 +761,7 @@ class MSAGPU(MSAHybrid): num_probes = np.int32(self.probe_positions[phase][batch].shape[0]) self.print_debug('batch: %s' % format(batch)) self.__propagate_beams(num_probes, batch, probe_d, propag_d, psi_k_d, norm_const, grid_steps_d, grid_range_d, self.probes[phase][batch], plan, ones_d, stream) self.probes[phase][batch], plan, ones_d, stream, transmit=transmit) # ctx.synchronize() # 4. clean-up for plan, probe_d, norm_const in zip(plans, probes_d, norm_consts): Loading Loading @@ -793,7 +792,7 @@ class MSAGPU(MSAHybrid): def __propagate_beams(self, num_probes, batch, psi_pos_d, propag_d, psi_k_d, norm_const_d, grid_steps_d, grid_range_d, psi_x_pos_pin, fft_plan_probe, ones_d, stream): psi_x_pos_pin, fft_plan_probe, ones_d, stream, transmit=True): """ :param batch: :param psi_pos_d: Loading Loading @@ -835,6 +834,7 @@ class MSAGPU(MSAHybrid): # 2. Propagate probes through atomic potential # ctx.synchronize() if transmit: for i in range(self.num_slices): # self.print_debug('Atomic potential slice #%d' % i) multwise_stack_func.prepared_async_call(grid_3d, block_3d, stream, psi_pos_d, self.pot_dev_ptr, num_probes, np.int32(i), Loading Loading
namsa/msa.py +22 −22 Original line number Diff line number Diff line Loading @@ -443,7 +443,6 @@ class MSAGPU(MSAHybrid): from pycuda.tools import clear_context_caches clear_context_caches() def build_potential_slices(self, slice_thickness): # find number of slices and atomic sites per slice self.slice_t = slice_thickness Loading Loading @@ -643,7 +642,7 @@ class MSAGPU(MSAHybrid): grid_1d = (int((shape_x * shape_y) / block_1d[0]), int(shape_z), 1) return block_1d, grid_1d def multislice(self, bandwidth=1/3, unified_mem=False, batch_size=256): def multislice(self, bandwidth=1/3, unified_mem=False, batch_size=256, transmit=True): """ :param bandwidth: Loading Loading @@ -762,7 +761,7 @@ class MSAGPU(MSAHybrid): num_probes = np.int32(self.probe_positions[phase][batch].shape[0]) self.print_debug('batch: %s' % format(batch)) self.__propagate_beams(num_probes, batch, probe_d, propag_d, psi_k_d, norm_const, grid_steps_d, grid_range_d, self.probes[phase][batch], plan, ones_d, stream) self.probes[phase][batch], plan, ones_d, stream, transmit=transmit) # ctx.synchronize() # 4. clean-up for plan, probe_d, norm_const in zip(plans, probes_d, norm_consts): Loading Loading @@ -793,7 +792,7 @@ class MSAGPU(MSAHybrid): def __propagate_beams(self, num_probes, batch, psi_pos_d, propag_d, psi_k_d, norm_const_d, grid_steps_d, grid_range_d, psi_x_pos_pin, fft_plan_probe, ones_d, stream): psi_x_pos_pin, fft_plan_probe, ones_d, stream, transmit=True): """ :param batch: :param psi_pos_d: Loading Loading @@ -835,6 +834,7 @@ class MSAGPU(MSAHybrid): # 2. Propagate probes through atomic potential # ctx.synchronize() if transmit: for i in range(self.num_slices): # self.print_debug('Atomic potential slice #%d' % i) multwise_stack_func.prepared_async_call(grid_3d, block_3d, stream, psi_pos_d, self.pot_dev_ptr, num_probes, np.int32(i), Loading