Commit 145be2eb authored by Ferreira Da Silva, Rafael's avatar Ferreira Da Silva, Rafael
Browse files

Update train.py

parent 9142e286
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -290,6 +290,8 @@ if __name__ == "__main__":
    ## Preprocess training sequences
    data_np = ary

    start_time = datetime.now()

    ## Prepare the lev1, lev2 training data, for lev2, we have different strategies based on the mapping_mode
    if config['mapping_mode'] == 0:  # single rank for level 2
        lev2_data = np.apply_over_axes(np.sum, data_np, [1, 2, 3])  # 1D array
@@ -754,6 +756,9 @@ if __name__ == "__main__":
    ### extract the individual tensor (sequence) with reversed index (for lower level ranks)
    rank_npz_dic = {}

    end_time = datetime.now()
    print0(f"Execution time: {(end_time - start_time).total_seconds()} seconds")

    if rank == gid * (int(scale_lev1) + 1):  ## roof rank
        pred_te_seq = te_raw_predictions['prediction'][0, :, 3]  ## single sequence for roof rank (for now)
        rank_npz_dic['roof'] = pred_te_seq.cpu().numpy() ### temp key for this roof rank