Commit 06b8efab authored by Nouamane Laanait's avatar Nouamane Laanait
Browse files

Minor updates (context creation/destruction, scattering calculations, ...)

parent 8b73fe99
Loading
Loading
Loading
Loading
+7 −4
Original line number Diff line number Diff line
@@ -341,10 +341,13 @@ class MSAHybrid(MSA):
        def _clean_up():
            """Tear down the module-level CUDA context and clear pycuda caches.

            Pops and detaches the global ``ctx`` (a pycuda context created
            elsewhere in this class), then resets it to ``None`` so a fresh
            context can be created later.  Teardown errors are deliberately
            swallowed: by the time this runs the context may already be
            invalid, and cleanup must never raise.
            """
            global ctx
            if ctx is not None:
                try:
                    ctx.pop()
                    ctx.detach()
                except Exception:
                    # best-effort teardown — the context may already be dead
                    pass
                finally:
                    # always drop the reference so a new context can be made
                    ctx = None
            # release pycuda's cached per-context resources regardless of outcome
            from pycuda.tools import clear_context_caches
            clear_context_caches()

+3 −1
Original line number Diff line number Diff line
@@ -8,6 +8,8 @@ def get_kinematic_reflection(unit_cell, top=3):
    xrd = XRDCalculator().get_pattern(unit_cell)
    hkls = np.array([list(itm.keys())[0] for itm in xrd.hkls])
    intens = xrd.y
    if top > intens.size:
        top = intens.size 
    top_ind = np.argsort(intens)[::-1][:top]
    hkl_vecs = hkls[top_ind]
    d_hkls = np.array(xrd.d_hkls)[top_ind]
+57 −40
Original line number Diff line number Diff line
@@ -25,49 +25,55 @@ def swap_out(lmdb_path):
    except subprocess.SubprocessError as e:
        print("rank %d: %s" % (comm_rank, format(e)))

    ## replace with lmdb from repo
    #user = os.environ.get('USER')
    #lmdb_repo = "/gpfs/alpine/lrn001/proj-shared/nl/sims/data/lmdb_bank_0405_3096"
    #lmdb_repo_list = os.listdir(lmdb_repo)
    #index = np.random.randint(0, len(lmdb_repo_list))
    #lmdb_path_src = os.path.join(lmdb_repo, lmdb_repo_list[index])
    #if not os.path.exists(lmdb_path_src):
    #    print('replacement file %s not found' % lmdb_path_src)
    #    return
    #src = lmdb_path_src 
    #trg = lmdb_path 
    #cp_args = "cp -r %s %s" %(src, trg)
    #cp_args = shlex.split(cp_args)
    #if not os.path.exists(trg):
    #    try:
    #        subprocess.run(cp_args, check=True)
    #    except subprocess.SubprocessError as e:
    #        print("rank %d: %s" % (comm_rank, format(e)))
    # replace with lmdb from repo
    user = os.environ.get('USER')
    lmdb_repo = "/gpfs/alpine/lrn001/proj-shared/nl/sims/data/lmdb_bank_64_256_256"
    lmdb_repo_list = os.listdir(lmdb_repo)
    index = np.random.randint(0, len(lmdb_repo_list))
    lmdb_path_src = os.path.join(lmdb_repo, lmdb_repo_list[index])
    if not os.path.exists(lmdb_path_src):
        print('replacement file %s not found' % lmdb_path_src)
        return
    src = lmdb_path_src 
    trg = lmdb_path 
    cp_args = "cp -r %s %s" %(src, trg)
    cp_args = shlex.split(cp_args)
    if not os.path.exists(trg):
        try:
            subprocess.run(cp_args, check=True)
        except subprocess.SubprocessError as e:
            print("rank %d: %s" % (comm_rank, format(e)))


def simulate(filehandle, h5g, idx= None, gpu_id=0, clean_up=False):
def simulate(filehandle, h5g, idx= None, gpu_id=0, record_names=["2d_potential_", "cbed_"], clean_up=False):
    try:
        # load cif and get sim params
        cif_path = h5g.attrs['cif_path'].decode('ascii')
        z_dir = h5g['z_dir'][()]
        y_dir = h5g['y_dir'][()]
        z_dir = [0,0,1]
        y_dir = np.array([[1,0,0],[0,1,0]])[np.random.randint(2)]
        #z_dir = h5g['z_dir'][()]
        #y_dir = h5g['y_dir'][()]
        slice_thickness = h5g.attrs['d_hkl']
        semi_angle = h5g.attrs['semi_angle']
        sampling = np.array([512,512])
        cell_dim = 100 
        #semi_angle = h5g.attrs['semi_angle']
        semi_angle = 0.01
        sampling = np.array([256,256])
        cell_dim = 50 
        slab_t = 200
        energy = 100e3
        grid_steps = np.array([32,32])
        grid_steps = np.array([8,8])
        probe_params = {'smooth_apert': True, 'scherzer': False, 'apert_smooth': 30, 
                'aberration_dict':{'C1':0., 'C3':0 , 'C5':0.}, 'spherical_phase': True}

        semi_angle = 0.01
        energy = np.random.randint(60,140)
        probe_params['aberration_dict']['C3'] = np.round(10**(np.random.rand()*7))
        # build supercell
        sp = SupercellBuilder(cif_path, verbose=False, debug=False)
        slice_thickness = max(1.0, min(5.0, get_slice_thickness(sp, direc=np.array([0,0,1]))))
        # filter out 
        angles = np.array(sp.structure.lattice.angles)
        angles = np.round(angles).astype(np.int)
        cutoff = np.array([90,90,90])
        tol = 2
        tol = 3
        cubic_cond = np.logical_not(np.logical_and(angles > cutoff - tol, angles < cutoff + tol)).any()
        hexag_cond_1 = np.logical_and(angles[:2] > cutoff[:2] - tol, angles[:2] < cutoff[:2] + tol).any()
        hexag_cond_2 = np.logical_and(angles[-1] > 120 - tol, angles[-1] < 120 + tol)
@@ -94,8 +100,8 @@ def simulate(filehandle, h5g, idx= None, gpu_id=0, clean_up=False):
        msa.multislice(bandwidth=1.)
    
        # process cbed and potential
        mask = msa.bandwidth_limit_mask(sampling, radius=1./3).astype(np.bool)
        proj_potential = process_potential(msa.potential_slices, mask=None, normalize=True, fp16=True)
        #mask = msa.bandwidth_limit_mask(sampling, radius=1./3).astype(np.bool)
        proj_potential = process_potential(msa.potential_slices, sampling=sampling, mask=None, normalize=True, fp16=True)
        cbed = process_cbed(msa.probes, normalize=True, fp16=True)
        
        # 
@@ -111,14 +117,16 @@ def simulate(filehandle, h5g, idx= None, gpu_id=0, clean_up=False):

        # check data integrity
        has_nan = np.all(np.isnan(cbed)) or np.all(np.isnan(proj_potential))
        wrong_shape = cbed.shape != (1024, 512, 512) or proj_potential.shape != (1, 512, 512)
        true_cbed_shape = (np.prod(grid_steps),) + tuple(sampling)
        true_pot_shape = (1,) + tuple(sampling)
        wrong_shape = cbed.shape != true_cbed_shape or proj_potential.shape != true_pot_shape 
        if has_nan or wrong_shape:
            print("rank=%d, skipped simulation=%s, error=NaN" % (comm_rank, cif_path))
            return False
        else:
            # write to h5 / tfrecords / lmdb
            if isinstance(filehandle, lmdb.Transaction):
                write_lmdb(filehandle, idx , cbed, proj_potential)
                write_lmdb(filehandle, idx , cbed, proj_potential, record_names=record_names)
            print('rank=%d, finished simulation=%s' % (comm_rank, cif_path))
            return True
    except Exception as e:
@@ -149,10 +157,13 @@ def generate_eval_data(samples, h5_params, outdir_path, save_mode="h5", runtime=
        env = lmdb.open(lmdbpath, map_size=int(100e9), map_async=True, writemap=True, create=True) # max of 100 GB
        with env.begin(write=True) as txn:
            # write lmdb headers
            record_names = ["2d_potential_", "cbed_"]
            headers = {b"input_dtype": bytes('float16', "ascii"),
                       b"input_shape": np.array([1024,512,512]).tostring(),
                       b"output_shape": np.array([1,512,512]).tostring(),
                       b"output_dtype": bytes('float16', "ascii")}
                       b"input_shape": np.array([64,256,256]).tostring(),
                       b"output_shape": np.array([1,256,256]).tostring(),
                       b"output_dtype": bytes('float16', "ascii"),
                       b"output_name": bytes(record_names[0], "ascii"),
                       b"input_name": bytes(record_names[1], "ascii")}
            for key, val in headers.items():
                txn.put(key, val)
            env.sync()
@@ -166,7 +177,8 @@ def generate_eval_data(samples, h5_params, outdir_path, save_mode="h5", runtime=
                    print('time=%3.2f, num_sims= %d' %(time() - t, idx * comm_size))
                if (time() - t_elaps) < runtime: 
                    #try:
                    status = simulate(txn, h5g, idx=idx-fail, gpu_id=int(np.mod(comm_rank, 6)), clean_up=manual)
                    status = simulate(txn, h5g, idx=idx-fail, gpu_id=int(np.mod(comm_rank, 6)), clean_up=manual,
                            record_names=record_names)
                    if status:
                        env.sync()
                        success += 1
@@ -195,10 +207,13 @@ def generate_training_data(samples, h5_params, outdir_path, save_mode="h5", runt
        env = lmdb.open(lmdbpath, map_size=int(100e9), map_async=True, writemap=True, create=True) # max of 100 GB
        with env.begin(write=True) as txn:
            # write lmdb headers
            record_names = ["2d_potential_", "cbed_"]
            headers = {b"input_dtype": bytes('float16', "ascii"),
                       b"input_shape": np.array([1024,512,512]).tostring(),
                       b"output_shape": np.array([1,512,512]).tostring(),
                       b"output_dtype": bytes('float16', "ascii")}
                       b"input_shape": np.array([64,256,256]).tostring(),
                       b"output_shape": np.array([1,256,256]).tostring(),
                       b"output_dtype": bytes('float16', "ascii"),
                       b"output_name": bytes(record_names[0], "ascii"),
                       b"input_name": bytes(record_names[1], "ascii")}
            for key, val in headers.items():
                txn.put(key, val)
            env.sync()
@@ -214,7 +229,8 @@ def generate_training_data(samples, h5_params, outdir_path, save_mode="h5", runt
                    print('time=%3.2f, num_sims= %d' %(time() - t, idx * comm_size))

                if (time() - t_elaps) < runtime:
                    status = simulate(txn, h5g, idx=idx-fail, gpu_id=int(np.mod(comm_rank, 6)), clean_up=manual)
                    status = simulate(txn, h5g, idx=idx-fail, gpu_id=int(np.mod(comm_rank, 6)), clean_up=manual, 
                            record_names=record_names)
                    if status:
                        env.sync()
                        success += 1
@@ -233,6 +249,7 @@ def generate_training_data(samples, h5_params, outdir_path, save_mode="h5", runt

def get_samples(h5_params, ratio=0.8, time_std=2):
    samples = np.array(list(h5_params.keys()))
    np.random.shuffle(samples)
    times = np.array([h5_params[sample].attrs['time'] for sample in samples])
    mean_time = times.mean()
    std_time = times.std()
+14 −8
Original line number Diff line number Diff line
@@ -87,13 +87,13 @@ def write_tfrecord(tfrecord_writer, cbed, potential, params):
    tfrecord_writer.write(example.SerializeToString()) 
    return

def write_lmdb(txn, idx, cbed, potential, params=None):
def write_lmdb(txn, idx, cbed, potential, record_names=["2d_potential_", "cbed_"], params=None):
    # barebone writing to file
    key = bytes('potential_%s' %format(idx), "ascii")
    key = bytes('%s%s' %(record_names[0],format(idx)), "ascii")
    sample = potential.flatten()
    sample = sample.tostring()
    txn.put(key, sample)
    key = bytes('cbed_%s' %format(idx), "ascii")
    key = bytes('%s%s' %(record_names[1],format(idx)), "ascii")
    sample = cbed.flatten()
    sample = cbed.tostring()
    txn.put(key, sample)
@@ -101,12 +101,12 @@ def write_lmdb(txn, idx, cbed, potential, params=None):
    # need to figure out how to write params for each sample
    return

def process_potential(pot_slices, normalize=True, mask=None, sampling=None, scale=[0,1], expand_dim=True, fp16=False):
def process_potential(pot_slices, normalize=True, mask=None, sampling=None, scale=[-1,1], expand_dim=True, fp16=False):
    proj_potential = np.imag(pot_slices).mean(0)
    proj_potential = gaussian_filter(proj_potential,1.2)
    snapshot = slice(int(proj_potential.shape[0]// 4), int(3 * proj_potential.shape[1]//4))
    proj_potential = proj_potential[snapshot, snapshot]
    proj_potential = resize(proj_potential,(512, 512), preserve_range=True, mode='constant', order=4)
    proj_potential = resize(proj_potential,sampling, preserve_range=True, mode='constant', order=4)
    #if mask is None:
    #    pass
        #mask = np.ones((sampling, sampling), dtype=np.bool)
@@ -116,7 +116,7 @@ def process_potential(pot_slices, normalize=True, mask=None, sampling=None, scal
    #    mask = np.logical_not(mask)
    #    proj_potential[mask] = 0
    proj_potential = (proj_potential - proj_potential.mean())/max(proj_potential.std(), 1./np.sqrt(proj_potential.size)) 
    proj_potential = proj_potential - proj_potential.min()
    #proj_potential = proj_potential - proj_potential.min()
    #proj_potential = (proj_potential - proj_potential.min())/(proj_potential.max() - proj_potential.min())
    #proj_potential = proj_potential * (scale[-1] - scale[0]) + scale[0]
    proj_potential = np.expand_dims(proj_potential, axis=0)
@@ -130,9 +130,9 @@ def process_potential(pot_slices, normalize=True, mask=None, sampling=None, scal
    return proj_potential

def process_cbed(cbed, normalize=True, scale=[-1, 1], fp16=False):
    cbed = np.sqrt(cbed)
    #cbed = np.sqrt(cbed)
    #cbed = cbed ** (1./3)
    cbed = (cbed - cbed.mean())/max(cbed.std(), 1./np.sqrt(cbed[0].size)) 
    cbed = (cbed - np.mean(cbed, axis=(1,-1), keepdims=True))/np.std(cbed, axis=(1,-1), keepdims=True)
    #cbed = (cbed - np.min(cbed, axis=(1,-1), keepdims=True))/(np.max(cbed, axis=(1,-1), keepdims=True) - np.min(cbed, axis=(1,-1), keepdims=True))
    #cbed = cbed * (scale[-1] - scale[0]) + scale[0]
    cbed = cbed.astype(np.float16)
@@ -195,3 +195,9 @@ def get_sim_params(sp_cell, slab_t= 200, sampling=np.array([512,512]), d_cutoff=
                'aberration_dict':{'C1':0., 'C3':0 , 'C5':0.}, 'spherical_phase': True}
    return sim_params

def get_slice_thickness(sp_cell, direc=np.array([0,0,1])):
    """Return the d-spacing of the strong kinematic reflection most parallel to *direc*.

    Parameters
    ----------
    sp_cell : object with a ``structure`` attribute (e.g. a SupercellBuilder)
        Its ``structure`` is passed to ``get_kinematic_reflection``.
    direc : array-like of length 3, default [0, 0, 1]
        Zone-axis direction against which reflections are compared.
        (Default is read-only, so the mutable-default is safe here.)

    Returns
    -------
    float
        d-spacing of the selected reflection.
    """
    hkls, dhkls = get_kinematic_reflection(sp_cell.structure, top=10)
    if hkls[0].size > 3:  # hexagonal 4-index (Miller-Bravais): drop the redundant i index
        hkls = np.array([[itm[0], itm[1], itm[-1]] for itm in hkls])
    # Pick the hkl most parallel to direc (smallest cross product).  The
    # original used np.cross(...).sum(1), but summing *signed* components can
    # cancel to zero for non-parallel vectors; the vector norm is the correct
    # parallelism measure.
    idx = np.argmin(np.linalg.norm(np.cross(hkls, direc), axis=1))
    return dhkls[idx]