Commit 009f656a authored by Laanait, Nouamane's avatar Laanait, Nouamane
Browse files

minor bug fixes to lmdb file writing

parent a229d4f0
Loading
Loading
Loading
Loading
+18 −6
Original line number Diff line number Diff line
@@ -6,15 +6,25 @@ import numpy as np
def read_lmdb(args):
    lmdb_path, delete = args[:]
    if delete:
        env = lmdb.open(lmdb_path, map_size=int(100e9), readahead=False, readonly=False, writemap=False, lock=True)
        env = lmdb.open(lmdb_path, map_size=int(100e9), readahead=False, readonly=False, 
                        writemap=True, lock=True, map_async=True)
    else:
        env = lmdb.open(lmdb_path, readahead=False, readonly=True, writemap=False, lock=False)
    num_samples = env.stat()['entries'] - 4 ## TODO: remove hard-coded # of headers by storing #_headers key
    with env.begin(write=False) as txn:
        input_shape = np.frombuffer(txn.get(b"input_shape"), dtype='int64')
        output_shape = np.frombuffer(txn.get(b"output_shape"), dtype='int64')
        input_dtype = np.dtype(txn.get(b"input_dtype").decode("ascii"))
        output_dtype = np.dtype(txn.get(b"output_dtype").decode("ascii"))
        output_name = txn.get(b"output_name").decode("ascii")
        input_name = txn.get(b"input_name").decode("ascii")
        num_headers = int.from_bytes(txn.get(b"header_entries"),"little")
#     num_samples = (env.stat()['entries'] - 6)//2 ## TODO: remove hard-coded # of headers by storing #samples key, val
    num_samples = int((env.stat()['entries'] - 6)/2)
    first_record = 0
    records = np.arange(first_record, num_samples//2)
    data_specs={'label_shape': [1,512,512], 'image_shape': [1024, 512, 512], 
          'label_dtype':'float16', 'image_dtype': 'float16', 'label_key':'potential_', 'image_key': 'cbed_'}
    print('file=%s, samples=%d' %(lmdb_path.split('/')[-1], num_samples//2))
    records = np.arange(first_record, num_samples)
    data_specs={'label_shape': list(output_shape), 'image_shape': list(input_shape),
            'label_dtype':output_dtype, 'image_dtype': input_dtype, 'label_key':output_name, 'image_key': input_name}
    print('file=%s, samples=%d' %(lmdb_path.split('/')[-1], num_samples))
    with env.begin(write=delete, buffers=True) as txn:
        for idx in records:
            image_key = bytes(data_specs['image_key']+str(idx), "ascii")
@@ -34,6 +44,8 @@ def read_lmdb(args):
                    label = np.random.uniform(low=0.0, high=1.0, size=data_specs['label_shape']).astype(data_specs['label_dtype']) 
                    txn.put(image_key, image.flatten().tostring())
                    txn.put(label_key, label.flatten().tostring())
                    print('file=%s, sample=%d, replaced' %(lmdb_path.split('/')[-1], idx))
                    env.sync()

def main(lmdb_dir, delete=False):
    lmdb_files = os.listdir(lmdb_dir)
+19 −15
Original line number Diff line number Diff line
@@ -16,6 +16,9 @@ comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
comm_rank = comm.Get_rank()

global counter
counter = 0

def swap_out(lmdb_path):
    # delete current lmdb dir
    rm_args = "rm -r %s" % lmdb_path
@@ -75,10 +78,10 @@ def simulate(filehandle, cif_path, idx= None, gpu_ctx=None, clean_up=False, reco
    slab_t = sim_params['slab_t']
    sim_params['space_group']= spgroup_num
    sim_params['material'] = matname
    energies = np.linspace(100,200,5)
    energies = np.linspace(100,200,4)
    np.random.shuffle(energies)
    write_counter = 0
    for current, energy in enumerate(energies):
    counter = 0
    for energy in energies:
        for index in range(len(sim_params['z_dirs'])): 
            # build supercell
            z_dir = sim_params['z_dirs'][index]
@@ -89,7 +92,7 @@ def simulate(filehandle, cif_path, idx= None, gpu_ctx=None, clean_up=False, reco

            # set simulation params
            slice_thickness = sim_params['d_hkl'][index]
            energy = sim_params['energy']
#             energy = sim_params['energy']
            semi_angle= sim_params['semi_angles'][index]
            probe_params = sim_params['probe_params']
            sampling = sim_params['sampling']
@@ -121,8 +124,8 @@ def simulate(filehandle, cif_path, idx= None, gpu_ctx=None, clean_up=False, reco
            sim_params = update_sim_params(sim_params, msa_cls=msa, sp_cls=sp)

            has_nan = np.all(np.isnan(cbed)) or np.all(np.isnan(proj_potential))
            wrong_shape = cbed.shape != (1024, 512, 512) or proj_potential.shape != (1, 512, 512)
            if has_nan :
            wrong_shape = cbed.shape != (64, 256, 256) or proj_potential.shape != (1, 128, 128)
            if has_nan or wrong_shape:
                # clean-up context and/or allocated memory
                print('rank=%d, found this many %d nan in cbed' %(comm_rank, np.where(np.isnan(cbed)==True)[0].size))
                print('rank=%d, found this many %d nan in proj_pot' %(comm_rank, np.where(np.isnan(proj_potential)==True)[0].size))
@@ -132,15 +135,15 @@ def simulate(filehandle, cif_path, idx= None, gpu_ctx=None, clean_up=False, reco
                if isinstance(filehandle, h5py.Group):
                    write_h5(filehandle, cbed, proj_potential, sim_params)
                elif isinstance(filehandle, lmdb.Transaction):
                    write_lmdb(filehandle, idx+current , cbed, proj_potential, record_names=record_names)
                    write_counter += 1
                    print('rank=%d, wrote sim_index=%d' % (comm_rank, write_counter))
                    write_lmdb(filehandle, idx + counter , cbed, proj_potential, record_names=record_names)
                    print('rank=%d, wrote sim_index=%d' % (comm_rank, idx+counter))
                    counter += 1
                elif isinstance(filehandle, tf.python_io.TFRecordWriter):
                    write_tfrecord(filehandle, cbed, proj_potential, sim_params)

            # free-up gpu memory
            msa.clean_up(ctx=None, vars=msa.vars)
    return True, write_counter
    return True, counter
    
def generate_data(cifpaths, outdir_path, save_mode="h5", data_mode="train", gpu_ctx=None, runtime=1800*0.9):
    t = time()
@@ -213,6 +216,7 @@ def generate_data(cifpaths, outdir_path, save_mode="h5", data_mode="train", gpu_
                       b"input_name": bytes(record_names[1], "ascii")}
            for key, val in headers.items():
                txn.put(key, val)
            txn.put(b"header_entries", bytes(len(list(headers.items()))))
            env.sync()
            
            # start simulation
@@ -226,23 +230,23 @@ def generate_data(cifpaths, outdir_path, save_mode="h5", data_mode="train", gpu_
                    print('time=%3.2f, num_sims= %d' %(time() - t, idx * comm_size))
                if (time() - t_elaps) < runtime:
                    try:
                        status, ret_counter = simulate(txn, cif_path, idx=idx+counter, gpu_ctx=gpu_ctx, clean_up=manual)
                        if ret_counter is not None:
                            counter += ret_counter
                        status, write_counter = simulate(txn, cif_path, idx=counter, gpu_ctx=gpu_ctx, clean_up=manual)
                        if status:
                            print('rank=%d, finished simulation=%s' % (comm_rank, cif_path.split('/')[-2:]))
#                             print('rank=%d, counter=%s' % (comm_rank, counter))
                            env.sync()
                            success += 1
                            counter += write_counter
                        else:
                            fail += 1
#                             print('rank=%d, counter=%s' % (comm_rank, counter))
                    except Exception as e:
                        print("rank=%d, skipped simulation=%s, error=%s" % (comm_rank, cif_path.split('/')[-2:], format(e)))
#                         print('rank=%d, counter=%s' % (comm_rank, counter))
                        fail += 1
                else:
                    env.sync()
                    break
        #if success < 4:
        #    swap_out(lmdbpath)

    #comm.Barrier()            
    # time the simulation run