Commit aaa7aa32 authored by zihanl's avatar zihanl
Browse files

remove finetune part

parent a87777bf
Loading
Loading
Loading
Loading
+1 −111
Original line number Diff line number Diff line
@@ -2,116 +2,10 @@
"""Model evaluation"""

from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.checkpointing import load_checkpoint
from tasks.finetune_utils import build_data_loader
from tasks.knwl_dialo.data import build_test_dataset
from tasks.knwl_dialo.data import build_test_dataset_for_prompting
from tasks.knwl_dialo.finetune import model_provider 
from tasks.knwl_dialo.finetune import process_batch 
from tasks.knwl_dialo.finetune import loss_func 
from tasks.knwl_dialo.finetune import forward_step 
from tasks.knwl_dialo.metrics import F1Metric
from tqdm import tqdm

def test_dataset_provider():
    """Build and return the test dataset for the configured module.

    Also sets `args.eval_iters` from the dataset size so downstream
    evaluation runs the right number of iterations.
    """
    args = get_args()
    print_rank_0('> building the test dataset for %s module ...' \
                    % args.module)

    prompting = args.prompt_type != ""
    if prompting:
        # Prompting-style evaluation: prompt examples are attached.
        print_rank_0('> evaluating ppl for prompting')
        build_kwargs = dict(
            test_data_path=args.test_data_path,
            prompt_file=args.prompt_file,
            module=args.module,
            max_seq_len=args.seq_length,
            num_prompt_examples=args.num_prompt_examples,
            three_turns=args.three_turns,
            dynamic_prompt=args.dynamic_prompt,
        )
        test_ds = build_test_dataset_for_prompting(**build_kwargs)
    else:
        # Finetuning-style evaluation.
        print_rank_0('> evaluating ppl for finetuning')
        build_kwargs = dict(
            test_data_path=args.test_data_path,
            module=args.module,
            max_seq_len=args.seq_length,
            last_turn=args.last_turn,
            no_control_code=args.no_control_code,
            add_separator=args.add_separator,
            add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
            remove_ctrl_sent=args.remove_ctrl_sent,
        )
        test_ds = build_test_dataset(**build_kwargs)

    print_rank_0("> finished creating the test dataset for %s module ..." \
                    % args.module)

    num_examples = len(test_ds)
    print_rank_0('> test set size: %d' % num_examples)
    # eval_iters drives how many evaluation iterations are run downstream.
    args.eval_iters = num_examples // args.global_batch_size
    print_rank_0('> evaluation iteration: %d' % args.eval_iters)

    return test_ds


def _build_test_iterator(test_dataset, task_collate_fn=None):
    """Build a dataloader over `test_dataset` and return an iterator on it.

    Args:
        test_dataset: dataset to evaluate on.
        task_collate_fn: optional collate function forwarded to the loader.

    Returns:
        An iterator over the test dataloader.
    """
    args = get_args()

    print_rank_0('building test dataloader ...')
    # Test loader; `not args.keep_last` controls dropping the ragged
    # final batch.
    test_dataloader = build_data_loader(test_dataset, args.micro_batch_size,
                                        args.num_workers, not args.keep_last,
                                        task_collate_fn)
    # Fix: use the iter() builtin instead of calling __iter__() directly.
    return iter(test_dataloader)


def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
    """Evaluate perplexity of the provided model on the test dataset.

    Args:
        test_dataset_provider: zero-arg callable returning the test dataset.
        model_provider: callable handed to setup_model_and_optimizer.
        forward_step: forward function handed to evaluate_and_print_results.
    """
    args = get_args()
    timers = get_timers()

    # Build the test dataset and an iterator over it.
    timers('test dataset/dataloader').start()
    test_dataset = test_dataset_provider()
    test_iterator = _build_test_iterator(test_dataset)
    timers('test dataset/dataloader').stop()

    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    # Optionally overwrite model weights from a pretrained checkpoint.
    timers('pretrained checkpoint').start()
    # Fix: `iteration` was unbound below when no pretrained checkpoint was
    # given, raising NameError at the `prefix` line; default it to 0.
    iteration = 0
    if args.pretrained_checkpoint is not None:
        # Temporarily redirect `args.load` to the pretrained checkpoint and
        # skip RNG state loading, then restore both.
        original_load = args.load
        args.load = args.pretrained_checkpoint
        original_rng = args.no_load_rng
        args.no_load_rng = True
        iteration = load_checkpoint(model, None, None)
        args.load = original_load
        args.no_load_rng = original_rng
        # This is critical when only model is loaded. We should make sure
        # main parameters are also updated.
        optimizer.reload_model_params()
    timers('pretrained checkpoint').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['test dataset/dataloader', 'model and optimizer',
                'pretrained checkpoint'])

    print_rank_0('evaluating ...')
    prefix = 'iteration {}'.format(iteration)
    evaluate_and_print_results(prefix, forward_step,
                               test_iterator, model,
                               iteration, False)

    print_rank_0('done :-)')


def evaluate_f1(guess_file, answer_file):
    """Evaluating F1 Score"""
@@ -146,9 +40,5 @@ def evaluate_f1(guess_file, answer_file):
def main():
    """Dispatch to the evaluation routine selected by `args.task`."""
    args = get_args()

    if 'PPL' in args.task:
        evaluate_ppl(test_dataset_provider, model_provider, forward_step)
    elif 'F1' in args.task:
        # Fix: this call must be indented inside the elif branch; the
        # original left it dedented, which is a syntax error (empty elif
        # body).
        evaluate_f1(args.guess_file, args.answer_file)
+0 −31
Original line number Diff line number Diff line
@@ -12,37 +12,6 @@ from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module


def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
    """
    Build attention masks and position ids for a left-to-right model.
    Different from the existing get_ltor_masks_and_position_ids function,
    we add padding to the input sequences to make sure their lengths are
    the same: every position after the first `eod_token_id` in a sequence
    is masked out for all query positions.

    Args:
        data: (micro_batch_size, seq_length) integer tensor of token ids.
        eod_token_id: end-of-document id; tokens after it are treated as
            padding.

    Returns:
        attention_mask: (micro_batch_size, 1, seq_length, seq_length) bool
            tensor, True where attention is NOT allowed.
        position_ids: (micro_batch_size, seq_length) long tensor of
            positions 0..seq_length-1.
    """

    micro_batch_size, seq_length = data.size()

    # Causal (lower-triangular) attention mask.
    attention_mask = torch.tril(torch.ones(
        (micro_batch_size, seq_length, seq_length), device=data.device)).view(
            micro_batch_size, 1, seq_length, seq_length)

    # Mask padded tokens: zero out every row after the first eod token.
    # Fix: the original scanned every position with a Python double loop and
    # re-applied the mask for every eod occurrence; only the first eod per
    # row matters, since later occurrences mask a subset of the same rows.
    for b in range(micro_batch_size):
        eod_positions = (data[b, :seq_length - 1] == eod_token_id) \
            .nonzero(as_tuple=True)[0]
        if eod_positions.numel() > 0:
            first_eod = int(eod_positions[0])
            attention_mask[b, 0, first_eod + 1:, :] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)

    # Convert attention mask to binary: True = masked.
    attention_mask = (attention_mask < 0.5)

    return attention_mask, position_ids


def switch(val1, val2, boolean):
    """Return either val1 or val2 depending on boolean"""