Commit a229d4f0 authored by Laanait, Nouamane's avatar Laanait, Nouamane
Browse files

fix to file writing for multiple sims from the same materials

parent 52eb891c
Loading
Loading
Loading
Loading
+21 −16
Original line number Diff line number Diff line
@@ -53,7 +53,7 @@ def simulate(filehandle, cif_path, idx= None, gpu_ctx=None, clean_up=False, reco
    latts = np.array(sp.structure.lattice.abc)
    if np.any(latts >= 10.) or np.all(latts >= 7):
        print('rank=%d, skipped simulation=%s, latt. const. too large=%s' % (comm_rank, cif_path, format(latts)))
        return    
        return False, None   
    angles = np.array(sp.structure.lattice.angles)
    angles = np.round(angles).astype(np.int)
    cutoff = np.array([90,90,90])
@@ -62,12 +62,12 @@ def simulate(filehandle, cif_path, idx= None, gpu_ctx=None, clean_up=False, reco
    hexag_cond_1 = np.logical_and(angles[:2] > cutoff[:2] - tol, angles[:2] < cutoff[:2] + tol).any()
    hexag_cond_2 = np.logical_and(angles[-1] > 120 - tol, angles[-1] < 120 + tol)
    hexag_cond = np.logical_not(hexag_cond_1 and hexag_cond_2)
#     if cubic_cond:
#         if hexag_cond:
#             print("rank=%d, skipped simulation=%s, msg=not cubic/hexagonal" % (comm_rank, cif_path.split('/')[-2:]))
#             return False
#         else:
#             pass
    if cubic_cond:
        if hexag_cond:
            print("rank=%d, skipped simulation=%s, msg=not cubic/hexagonal" % (comm_rank, cif_path.split('/')[-2:]))
            return False, None
        else:
            pass
        
    sim_params = get_sim_params(sp, slab_t= 100, cell_dim = 50, grid_steps=np.array([8,8]), orientation_num=5, 
                                sampling=np.array([256,256]))
@@ -75,8 +75,9 @@ def simulate(filehandle, cif_path, idx= None, gpu_ctx=None, clean_up=False, reco
    slab_t = sim_params['slab_t']
    sim_params['space_group']= spgroup_num
    sim_params['material'] = matname
    energies = np.linspace(60,200,10)
    energies = np.linspace(100,200,5)
    np.random.shuffle(energies)
    write_counter = 0
    for current, energy in enumerate(energies):
        for index in range(len(sim_params['z_dirs'])): 
            # build supercell
@@ -119,7 +120,6 @@ def simulate(filehandle, cif_path, idx= None, gpu_ctx=None, clean_up=False, reco
            # update sim_params dict
            sim_params = update_sim_params(sim_params, msa_cls=msa, sp_cls=sp)

        #     print(proj_potential.shape, cbed.shape)
            has_nan = np.all(np.isnan(cbed)) or np.all(np.isnan(proj_potential))
            wrong_shape = cbed.shape != (1024, 512, 512) or proj_potential.shape != (1, 512, 512)
            if has_nan :
@@ -133,12 +133,14 @@ def simulate(filehandle, cif_path, idx= None, gpu_ctx=None, clean_up=False, reco
                    write_h5(filehandle, cbed, proj_potential, sim_params)
                elif isinstance(filehandle, lmdb.Transaction):
                    write_lmdb(filehandle, idx+current , cbed, proj_potential, record_names=record_names)
                    write_counter += 1
                    print('rank=%d, wrote sim_index=%d' % (comm_rank, write_counter))
                elif isinstance(filehandle, tf.python_io.TFRecordWriter):
                    write_tfrecord(filehandle, cbed, proj_potential, sim_params)

            # free-up gpu memory
            msa.clean_up(ctx=None, vars=msa.vars)
    return True
    return True, write_counter
    
def generate_data(cifpaths, outdir_path, save_mode="h5", data_mode="train", gpu_ctx=None, runtime=1800*0.9):
    t = time()
@@ -166,7 +168,7 @@ def generate_data(cifpaths, outdir_path, save_mode="h5", data_mode="train", gpu_
                        print('time=%3.2f, num_sims= %d' %(time() - t, idx * comm_size))

                    try: 
                        status = simulate(h5g, cif_path, gpu_ctx=gpu_ctx, clean_up=manual)
                        status, _ = simulate(h5g, cif_path, gpu_ctx=gpu_ctx, clean_up=manual)
                        if status:
                            print('rank=%d, finished simulation=%s' % (comm_rank, cif_path.split('/')[-2:]))
                            f.flush()
@@ -187,7 +189,7 @@ def generate_data(cifpaths, outdir_path, save_mode="h5", data_mode="train", gpu_
                if comm_rank == 0 and bool(idx % 500):
                    print('time=%3.2f, num_sims= %d' %(time() - t, idx * comm_size))
                try: 
                    status = simulate(tfrec, cif_path, gpu_ctx=gpu_ctx, clean_up=manual)
                    status, _ = simulate(tfrec, cif_path, gpu_ctx=gpu_ctx, clean_up=manual)
                    if status:
                        print('rank=%d, finished simulation=%s' % (comm_rank, cif_path.split('/')[-2:]))
                    else:
@@ -199,7 +201,7 @@ def generate_data(cifpaths, outdir_path, save_mode="h5", data_mode="train", gpu_
    # LMDB            
    elif save_mode == "lmdb":
        lmdbpath = os.path.join(outdir_path, 'batch_%s_%d.db' % (data_mode, comm_rank))
        env = lmdb.open(lmdbpath, map_size=int(100e9), map_async=True, writemap=True, create=True) # max of 100 GB
        env = lmdb.open(lmdbpath, map_size=int(10e12), map_async=True, writemap=True, create=True) # max of 100 GB
        with env.begin(write=True) as txn:
            # write lmdb headers
            record_names = ["2d_potential_", "cbed_"]
@@ -216,6 +218,7 @@ def generate_data(cifpaths, outdir_path, save_mode="h5", data_mode="train", gpu_
            # start simulation
            fail = 0
            success = 0
            counter = 0
            for (idx, cif_path) in enumerate(cifpaths[comm_rank:num_sims:comm_size]):
                manual = idx < ( num_sims - comm_size) 
                spgroup_num, matname = parse_cif_path(cif_path)
@@ -223,7 +226,9 @@ def generate_data(cifpaths, outdir_path, save_mode="h5", data_mode="train", gpu_
                    print('time=%3.2f, num_sims= %d' %(time() - t, idx * comm_size))
                if (time() - t_elaps) < runtime:
                    try:
                        status = simulate(txn, cif_path, idx=idx-fail, gpu_ctx=gpu_ctx, clean_up=manual)
                        status, ret_counter = simulate(txn, cif_path, idx=idx+counter, gpu_ctx=gpu_ctx, clean_up=manual)
                        if ret_counter is not None:
                            counter += ret_counter
                        if status:
                            print('rank=%d, finished simulation=%s' % (comm_rank, cif_path.split('/')[-2:]))
                            env.sync()
@@ -270,9 +275,9 @@ def main(cifdir_path, outdir_path, save_mode, runtime=1800):
    global t_elaps
    t_elaps = time()
    cif_paths = get_cif_paths(cifdir_path)
    samples_train, samples_dev, samples_test = get_samples(cif_paths, ratio=0.9)
    samples_train, samples_dev, samples_test = get_samples(cif_paths, ratio=0.95)
    ctx = setup_device(gpu_id=int(np.mod(comm_rank, 6)))
    generate_data(samples_train, outdir_path, save_mode=save_mode, data_mode='train', runtime=runtime*0.9, gpu_ctx=ctx)
    generate_data(samples_train, outdir_path, save_mode=save_mode, data_mode='train', runtime=runtime*0.95, gpu_ctx=ctx)
    print('rank=%d, finished simulating training data' % comm_rank)
#     generate_data(samples_dev, outdir_path, save_mode=save_mode, data_mode='dev', runtime=runtime*0.9, gpu_ctx=ctx)
#     print('rank=%d, finished simulating dev data' % comm_rank)