Commit 73b967a0 authored by Nouamane Laanait's avatar Nouamane Laanait
Browse files

fully-functional sim ensembles code

parent 319c6313
Loading
Loading
Loading
Loading
+34 −32
Original line number Diff line number Diff line
@@ -33,22 +33,20 @@ def parse_cif_path(cif_path):
    spgroup_num = re.findall('\d+',spgroup)[0]
    return spgroup_num, matname

def write_h5(h5file):
# def h5py_export(mat_g, material, sim_params_dir, cbed, pot):
    cbed_data = process_cbed(cbed)
    potential_data = process_pot(pot)
    json_labels = load_json_label(material+'.json', sim_params_dir)
def write_h5(h5group, cbed, potential, params):
    # try:
    dset_cbed = mat_g.create_dataset('CBED', shape=cbed_data.shape, data=cbed_data,
            dtype=np.float16)
    dset_pot = mat_g.create_dataset('potential', shape=potential_data.shape,
            data=potential_data, dtype=np.float16)
    num_itms = len(h5group.items())
    g = h5group.create_group('sample_%d' % num_itms)
    dset_cbed = g.create_dataset('CBED',  data=cbed)
    dset_pot = g.create_dataset('potential', data=potential)
    # need to figure out how to assign attributes to each dset and the parent group. 
    # for key in json_labels
    for key, itm in json_labels['sim'].items():
        dset_cbed.attrs[key] = itm
    for key, itm in json_labels['label'].items():
        dset_pot.attrs[key] = itm
    return potential_data.min(), potential_data.max(), potential_data.mean()
#    for key, itm in json_labels['sim'].items():
#        dset_cbed.attrs[key] = itm
#    for key, itm in json_labels['label'].items():
#        dset_pot.attrs[key] = itm
#    return potential_data.min(), potential_data.max(), potential_data.mean()
    return

def set_sim_params(unit_cell):
    """
@@ -56,7 +54,7 @@ def set_sim_params(unit_cell):
    """
    pass
    
def simulate(cif_path, gpu_rank=0, clean_up=False):
def simulate(h5g, cif_path, gpu_rank=0, clean_up=False):
    # build supercell
    sp = SupercellBuilder(cif_path, verbose=False, debug=False)
    sim_params = set_sim_params(sp)
@@ -66,17 +64,17 @@ def simulate(cif_path, gpu_rank=0, clean_up=False):
    sp.make_orthogonal_supercell(supercell_size=np.array([2*34.6,2*34.6,198.0]),
                             projec_1=y_dir, projec_2=z_dir)
    # set simulation params
    slice_thickness = 0.25 # Angstroms
    slice_thickness = 0.5 # Angstroms
    en = 100 # keV
    semi_angle= 4e-3 # radians
    max_ang = 200e-3 # radians
    step = 2.5 # Angstroms
    semi_angle= 10e-3 # radians
    max_ang = 150e-3 # radians
    step = 2.1 # Angstroms
    aberration_params = {'C1':500., 'C3': 3.3e7, 'C5':44e7}
    probe_params = {'smooth_apert': True, 'scherzer': False, 'apert_smooth': 60, 
                'aberration_dict':aberration_params, 'spherical_phase': True}
    
    # simulate
    msa = MSAGPU(en, semi_angle, sp.supercell_sites, sampling=np.array([512,512]),
    msa = MSAGPU(en, semi_angle, sp.supercell_sites, sampling=np.array([256,256]),
                 verbose=False, debug=False)
    ctx = msa.setup_device(gpu_rank=gpu_rank)
    msa.calc_atomic_potentials()
@@ -88,31 +86,35 @@ def simulate(cif_path, gpu_rank=0, clean_up=False):
    msa.multislice()
    
    # write to h5
    write_h5(None)
    print('rank %d: finished simulation of %s' % cif_path)
    write_h5(h5g, msa.probes, msa.potential_slices.sum(0), None)
    print('rank=%d, simulation=%s' % (comm_rank, cif_path))
    
    # clean-up context
    # clean-up context and/or allocated memory
    if clean_up and ctx is not None:
        msa.clean_up(ctx=ctx, vars=msa.vars)

    else:
        msa.clean_up(ctx=None, vars=msa.vars)

def main(cifdir_path, h5dir_path):
    cifpath_list = get_cif_paths(cifdir_path)
    cifpath_list = np.array(cifpath_list)[::-1]
    h5path = os.path.join(h5dir_path, 'batch_%d.h5'% comm_rank)
    if os.path.exists(h5path):
        mode ='r+'
    else:
        mode ='w'
    with h5py.File(h5path, mode=mode) as f:
        for idx in range(comm_rank, 100, comm_size):
        for idx in range(comm_rank, len(cifpath_list), comm_size):
            cif_path = cifpath_list[idx]
            manual = idx < 100 - comm_size 
            manual = idx < (len(cifpath_list) - comm_size) 
            spgroup_num, matname = parse_cif_path(cif_path)
            try:
                h5g = f.create_group(matname)
            except Exception as e:
                print("rank=%d" % comm_rank, e, "group=%s exists" % matname)
                h5g = f[matname]
            if comm_rank == 0:
                print('current idx: %d' %idx)
            simulate(cif_path, gpu_rank=int(np.mod(comm_rank, 4)), clean_up=manual)

            simulate(h5g, cif_path, gpu_rank=int(np.mod(comm_rank, 6)), clean_up=manual)

if __name__ == "__main__":
    if len(sys.argv) > 2: