Commit 1080dcc7 authored by Laanait, Nouamane
Browse files

adding lmdb_write capability

parent 531b5e5a
Loading
Loading
Loading
Loading
+47 −36
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ import h5py
from mpi4py import MPI
from itertools import chain
import tensorflow as tf
import lmdb

comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
@@ -14,7 +15,7 @@ comm_rank = comm.Get_rank()



def simulate(filehandle, cif_path, gpu_id=0, clean_up=False):
def simulate(filehandle, cif_path, idx= None, gpu_id=0, clean_up=False):
    # load cif and get sim params
    spgroup_num, matname = parse_cif_path(cif_path)
    index = 0 
@@ -53,17 +54,19 @@ def simulate(filehandle, cif_path, gpu_id=0, clean_up=False):
    
    # process cbed and potential
    mask = msa.bandwidth_limit_mask(sampling, radius=1./3).astype(np.bool)
    proj_potential = process_potential(msa.potential_slices, mask=mask)
    proj_potential = process_potential(msa.potential_slices, mask=mask, fp16=True)
    cbed = process_cbed(msa.probes, fp16=True)
    
    # update sim_params dict
    sim_params = update_sim_params(sim_params, msa_cls=msa, sp_cls=sp)
    
    # write to h5 / tfrecords
    # write to h5 / tfrecords / lmdb
    if isinstance(filehandle, h5py.Group):
         write_h5(filehandle, msa.probes, proj_potential, sim_params)
    else:
         write_tfrecord(filehandle, msa.probes, proj_potential, sim_params)

         write_h5(filehandle, cbed, proj_potential, sim_params)
    elif isinstance(filehandle, lmdb.Transaction):
         write_lmdb(filehandle, idx + index, cbed, proj_potential, sim_params)
    elif isinstance(filehandle, tf.python_io.TFRecordWriter):
         write_tfrecord(filehandle, cbed, proj_potential, sim_params)
    print('rank=%d, simulation=%s' % (comm_rank, cif_path))
    
    # clean-up context and/or allocated memory
@@ -75,8 +78,11 @@ def simulate(filehandle, cif_path, gpu_id=0, clean_up=False):

def main(cifdir_path, outdir_path, save_mode="h5"):
    t = time()
    cifpath_list = get_cif_paths(cifdir_path)
    cifpaths = get_cif_paths(cifdir_path)
    batch_num, _ = np.divmod(comm_rank, 6)
    num_sims = cifpaths.size
    num_sims = 16
    
    if save_mode == "h5": 
        h5path = os.path.join(outdir_path, 'batch_%d.h5'% comm_rank)
        if os.path.exists(h5path):
@@ -84,9 +90,8 @@ def main(cifdir_path, outdir_path, save_mode="h5"):
        else:
            mode ='w'
        with h5py.File(h5path, mode=mode) as f:
            for idx in range(comm_rank, len(cifpath_list), comm_size):
                cif_path = cifpath_list[idx]
                manual = idx < ( len(cifpath_list) - comm_size) 
            for (idx, cif_path) in enumerate(cifpaths[comm_rank:num_sims:comm_size]):
                manual = idx < ( num_sims - comm_size) 
                spgroup_num, matname = parse_cif_path(cif_path)
                try:
                    h5g = f.create_group(matname)
@@ -94,46 +99,52 @@ def main(cifdir_path, outdir_path, save_mode="h5"):
                    print("rank=%d" % comm_rank, e, "group=%s exists" % matname)
                    h5g = f[matname]
                if comm_rank == 0 and bool(idx % 500):
                    print('time=%3.2f, idx= %d' %(time() - t, idx))
                    print('time=%3.2f, num_sims= %d' %(time() - t, idx * comm_size))
                simulate(h5g, cif_path, gpu_id=int(np.mod(comm_rank, 6)), clean_up=manual)
    else:
                
    elif save_mode == "tfrecord":
        tfrecpath = os.path.join(outdir_path, 'batch_%d.tfrecords'% comm_rank)   
        with tf.python_io.TFRecordWriter(tfrecpath) as tfrec:
            for idx in range(comm_rank, len(cifpath_list), comm_size):
                cif_path = cifpath_list[idx]
                manual = idx < ( len(cifpath_list) - comm_size) 
            for (idx, cif_path) in enumerate(cifpaths[comm_rank:num_sims:comm_size]):
                manual = idx < ( num_sims - comm_size) 
                spgroup_num, matname = parse_cif_path(cif_path)
                if comm_rank == 0 and bool(idx % 500):
                    print('time=%3.2f, idx= %d' %(time() - t, idx))
                    print('time=%3.2f, num_sims= %d' %(time() - t, idx * comm_size))
                simulate(tfrec, cif_path, gpu_id=int(np.mod(comm_rank, 6)), clean_up=manual)
                
    elif save_mode == "lmdb":
        lmdbpath = os.path.join(outdir_path, 'batch_%d.db' % comm_rank)
        env = lmdb.open(lmdbpath, map_size=int(50e9*1024)) # max of 50 GB
        with env.begin(write=True) as txn:
            for (idx, cif_path) in enumerate(cifpaths[comm_rank:num_sims:comm_size]):
                manual = idx < ( num_sims - comm_size) 
                spgroup_num, matname = parse_cif_path(cif_path)
                if comm_rank == 0 and bool(idx % 500):
                    print('time=%3.2f, num_sims= %d' %(time() - t, idx * comm_size))
                simulate(txn, cif_path, idx=idx, gpu_id=int(np.mod(comm_rank, 6)), clean_up=manual)
            # write lmdb headers
            headers = {b"input_dtype": bytes('float16', "ascii"),
                       b"input_shape": np.array([1024,512,512]).tostring(),
                       b"output_shape": np.array([1,512,512]).tostring(),
                       b"output_dtype": bytes('float16', "ascii")}
            for key, val in headers.items():
                txn.put(key, val)
                
    # time the simulation run        
    sim_t = time() - t
    if comm_rank == 0:
        print("took %3.3f seconds" % sim_t)    

def main_test(cifdir_path):
    cifpath_list = get_cif_paths(cifdir_path)
    idx = np.random.randint(0, len(cifpath_list))
    for _ in range(1000):
        cif_path = cifpath_list[idx]
        spgroup_num, matname = parse_cif_path(cif_path)
        sp = SupercellBuilder(cif_path, verbose=False, debug=False)
        sim_params = set_sim_params(sp, energy=100e3, orientation_num=3, beam_overlap=2)
        y_dir, z_dir = sim_params['y_dirs'][0], sim_params['z_dirs'][0]
        sp.build_unit_cell()
        sp.make_orthogonal_supercell(supercell_size=np.array([2*34.6,2*34.6,198.0]),
                             projec_1=y_dir, projec_2=z_dir)
        
        print("rank=%d, spgroup= %s, material=%s, z_dir=%s, y_dir=%s, d_hkl=%2.2f, semi_angle=%2.2f" 
              % (comm_rank, spgroup_num, matname, z_dir, y_dir, sim_params['d_hkl'][0], sim_params['semi_angles'][0]))
        if comm_rank == 0:
            print('current idx: %d' %idx)
    cifpaths_train, cifpaths_test= get_cif_paths(cifdir_path, ratio=0.2)
    print("train", cifpaths_train[:10])
    print("test", cifpaths_test[:10])
                
if __name__ == "__main__":
    if len(sys.argv) > 2:
        cifdir_path, outdir_path, save_mode = sys.argv[-3:]
        if save_mode not in ["h5","tfrecord"]:
            print("specify saving format")
        if save_mode not in ["h5", "tfrecord", "lmdb"]:
            print("saving format not of h5, tfrecord, lmdb")
            sys.exit()
        main(cifdir_path, outdir_path, save_mode)
    elif len(sys.argv) == 2:
+29 −2
Original line number Diff line number Diff line
@@ -34,7 +34,7 @@ def pop_DS(lst):
        if '.DS_Store' in itm:
            lst.pop(i)

def get_cif_paths(root_path):
def get_cif_paths(root_path, ratio=None):
    space_group_dirs = os.listdir(root_path)
    pop_DS(space_group_dirs)
    cifpath_list = []
@@ -44,6 +44,11 @@ def get_cif_paths(root_path):
        cif_paths = [os.path.join(os.path.join(root_path,spg_dir),cif_name) for cif_name in cif_list]
        cifpath_list.append(cif_paths)
    cifpath_list = list(chain.from_iterable(cifpath_list))
    cifpath_list = np.array(cifpath_list)
    np.random.shuffle(cifpath_list)
    if ratio is not None:
        test_size = int(cifpath_list.size * ratio)
        return cifpath_list[:test_size], cifpath_list[test_size:]
    return cifpath_list 

def parse_cif_path(cif_path):
@@ -80,7 +85,21 @@ def write_tfrecord(tfrecord_writer, cbed, potential, params):
    tfrecord_writer.write(example.SerializeToString()) 
    return

def process_potential(pot_slices, mask=None, sampling=None):
def write_lmdb(txn, idx, cbed, potential, params):
    """Store one sample's potential and cbed arrays as raw bytes via ``txn.put``.

    Two entries are written, keyed by the ASCII byte strings
    ``potential_<idx>`` and ``cbed_<idx>``. Each value is the array's
    raw buffer serialised in C (row-major) order; shape and dtype are
    not stored alongside the data.

    Args:
        txn: object exposing ``put(key, value)`` for byte keys/values.
        idx: sample index embedded in both keys.
        cbed: array-like exposing ``tobytes()``.
        potential: array-like exposing ``tobytes()``.
        params: currently unused; per-sample parameters are not written yet.

    Returns:
        None.
    """
    # ``tobytes()`` already emits C-order bytes, so a prior ``flatten()``
    # only adds a copy; ``tostring()`` is a deprecated alias that newer
    # NumPy releases no longer provide.
    suffix = format(idx)
    txn.put(bytes('potential_%s' % suffix, "ascii"), potential.tobytes())
    txn.put(bytes('cbed_%s' % suffix, "ascii"), cbed.tobytes())

    # TODO: need to figure out how to write params for each sample
    return

def process_potential(pot_slices, mask=None, sampling=None, expand_dim=True, fp16=False):
    proj_potential = np.imag(pot_slices).sum(0)
    if mask is None:
        mask = np.ones((sampling, sampling), dtype=np.bool)
@@ -89,8 +108,16 @@ def process_potential(pot_slices, mask=None, sampling=None):
    else:
        mask = np.logical_not(mask)
    proj_potential[mask] = 0
    if expand_dim:
        proj_potential = np.expand_dims(proj_potential, axis=0)
    if fp16:
        return proj_potential.astype(np.float16)
    return proj_potential

def process_cbed(cbed, fp16=False):
    """Return ``cbed``, optionally cast to half precision.

    Args:
        cbed: array exposing ``astype``.
        fp16: when true, return a ``np.float16`` copy of ``cbed``.

    Returns:
        The ``np.float16`` cast when ``fp16`` is true, otherwise ``cbed``
        itself, unchanged.
    """
    if fp16:
        return cbed.astype(np.float16)
    # Without this return the function yielded None for the default
    # fp16=False, leaving callers with nothing to write.
    return cbed

def update_sim_params(sim_params, msa_cls=None, sp_cls=None):
    # msa params
    if msa_cls is not None: