Commit 531b5e5a authored by Laanait, Nouamane
Browse files

hack to avoid wrong probe positions when batch_size != self.num_probes

parent 23c6b667
Loading
Loading
Loading
Loading
+10 −23
Original line number Diff line number Diff line
@@ -230,7 +230,7 @@ class MSA(object):
            grid_range_stop = (0.5 + self.dims[:2]/2)/2
            x_pos, y_pos = np.mgrid[grid_range_start[0]:grid_range_stop[0]:-1j*grid_steps[0], 
                                    grid_range_start[1]:grid_range_stop[1]:-1j*grid_steps[1]]
            grid_steps_x, grid_steps_y = np.floor_divide(grid_steps, 2)  
            grid_steps_x, grid_steps_y = grid_steps.astype(np.int) 
            grid_range_x = np.array([grid_range_start[0], grid_range_stop[0]])
            grid_range_y = np.array([grid_range_start[1], grid_range_stop[1]])
            probe_pos = np.array([[-x, y] for y, x in zip(y_pos.flatten()[::-1], x_pos.flatten())])
@@ -604,23 +604,6 @@ class MSAGPU(MSAHybrid):
        apert_d.free()
        cufft.cufftDestroy(fft_plan.handle)
        
        # unregister host memory
        #self.apert.unregister()
        #self.psi_k.unregister()
        #self.psi.unregister()

    # def generate_probe_positions(self, probe_step=np.array([0.1, 0.1]), probe_range=np.array([[0., 1.0], [0., 1.0]])):
    #     grid_steps_x, grid_steps_y = np.floor(np.diff(probe_range).flatten() * self.dims[:2] / probe_step).astype(np.int)
    #     grid_range_x, grid_range_y = [(probe_range[i] - np.ones((2,)) * 0.5) * self.dims[i]
    #                                   for i in range(2)]
    #     y_pos, x_pos = np.mgrid[grid_range_y[0]: grid_range_y[1]: -1j * grid_steps_y,
    #                    grid_range_x[0]: grid_range_x[1]: -1j * grid_steps_x]
    #     probe_pos = np.array([[y, -x] for y, x in zip(y_pos.flatten()[::-1], x_pos.flatten())])
    #     self.grid_steps = np.array([grid_steps_x, grid_steps_y])
    #     self.grid_range = np.array([grid_range_x, grid_range_y]).flatten()
    #     self.probe_positions = probe_pos
    #     self.num_probes = np.int32(probe_pos.shape[0])

    @staticmethod
    def _get_blockgrid(shapes, mode='2D'):
        # define block/grid threads
@@ -651,13 +634,15 @@ class MSAGPU(MSAHybrid):
            grid_1d = (int((shape_x * shape_y) / block_1d[0]), int(shape_z), 1)
            return block_1d, grid_1d

    def multislice(self, bandwidth=1/3, unified_mem=False, batch_size=256, transmit=True):
    def multislice(self, bandwidth=1/3, unified_mem=False, batch_size=None, transmit=True):
        """

        :param bandwidth:
        :return:
        """

        ## TODO: when batch_size != self.num_probes, probe positions get scrambled!!!
        if batch_size == None:
            batch_size = self.num_probes 
        # checks
        if isinstance(self.potential_slices, np.ndarray) is False:
            warn('Potential slices must be calculated first before calling multi_slice\n. '
@@ -684,7 +669,7 @@ class MSAGPU(MSAHybrid):
        self.print_debug('block, grid:', block_3d, grid_3d)

        # allocate memory
        self.probes = np.empty((num_probes, shape_y, shape_x), dtype=np.complex64)
        # self.probes = np.empty((self.num_probes, shape_y, shape_x), dtype=np.complex64)
        self.propag = cuda.aligned_empty((int(self.sampling[0]), int(self.sampling[1])), np.complex64)
        self.vars.append(self.propag.base)
        self.propag = cuda.register_host_memory(self.propag)
@@ -739,6 +724,7 @@ class MSAGPU(MSAHybrid):
                phase = slice(i * self.max_probes, self.num_probes)
            phases.append(phase)
        self.print_debug('Simulation split into %d serial phases.' % len(phases))
        self.print_debug('Phases: %s ' %format(phases))
        for (i, phase) in enumerate(phases):
            t = time()
            if self.probes[phase].shape[0] == 0: break
@@ -771,7 +757,7 @@ class MSAGPU(MSAHybrid):
                self.print_debug('batch: %s' % format(batch))
                self.__propagate_beams(num_probes, batch, probe_d, propag_d, psi_k_d, norm_const, grid_steps_d, grid_range_d,
                                     self.probes[phase][batch], plan, ones_d, stream, transmit=transmit)
            # ctx.synchronize()
            ctx.synchronize()
        # 4. clean-up
            for plan, probe_d, norm_const in zip(plans, probes_d, norm_consts):
               cufft.cufftDestroy(plan.handle)
@@ -781,7 +767,8 @@ class MSAGPU(MSAHybrid):
               del probe_d, norm_const, plan
            ctx.synchronize()
            self.print_verbose('finished simulation phase #%d' % i)
            self.probes[phase][batch] = self.probes[phase][batch]/self.normalization
            self.probes /= self.normalization
            # self.probes[phase][batch] = self.probes[phase][batch]/self.normalization
            sim_t = time()-t
            self.print_verbose('Propagated %d probes in %2.4f s' % (self.probe_positions[phase].shape[0], sim_t))