Commit 67e85f7e authored by Wang, Xiao

add comments for megatron training

parent 1aa66511
+20 −6
@@ -122,7 +122,8 @@ def pretrain(train_valid_test_dataset_provider,
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

-    # Data stuff.
+    #XW: build the train, valid, and test dataset iterators.
+    #XW: for the PubMed dataset: 24000 training samples, 360 validation samples, 120 test samples.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        all_data_iterators = [
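
As a side note on the data-iterator comment above, here is a minimal standalone sketch of the same idea using plain PyTorch DataLoaders rather than Megatron's internal builders; the helper name build_iterators and the toy sizes are illustrative only, scaled-down stand-ins for the 24000/360/120 PubMed split.

import torch
from torch.utils.data import DataLoader, TensorDataset

def build_iterators(train_ds, valid_ds, test_ds, micro_batch_size=4):
    # Shuffle only the training split; iter() gives the interface the training
    # loop expects, i.e. next(data_iterator) returning one micro-batch.
    loaders = [DataLoader(ds, batch_size=micro_batch_size, shuffle=(i == 0))
               for i, ds in enumerate((train_ds, valid_ds, test_ds))]
    return [iter(dl) for dl in loaders]

dummy = lambda n: TensorDataset(torch.zeros(n, 8))   # toy stand-in datasets
train_it, valid_it, test_it = build_iterators(dummy(240), dummy(36), dummy(12))
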
@@ -351,9 +352,9 @@ def setup_model_and_optimizer(model_provider_func, model_type):

    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
-    optimizer = get_megatron_optimizer(unwrapped_model)
+    optimizer = get_megatron_optimizer(unwrapped_model)   #XW: get the optimizer; the default is Adam.

-    lr_scheduler = get_learning_rate_scheduler(optimizer)
+    lr_scheduler = get_learning_rate_scheduler(optimizer)   #XW: applies an annealing learning-rate decay schedule.

    if args.load is not None:
        timers = get_timers()
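
A rough analogue of the two calls commented above, to make the defaults concrete: this uses vanilla torch.optim instead of get_megatron_optimizer, and CosineAnnealingLR as a stand-in for Megatron's annealing scheduler, so it is a sketch of the pattern rather than the actual implementation.

import torch

model = torch.nn.Linear(16, 16)
# Adam is the default optimizer choice noted above.
optimizer = torch.optim.Adam(model.parameters(), lr=1.5e-4, weight_decay=0.01)
# Stand-in for the annealing learning-rate decay.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

for _ in range(3):                                   # a few dummy steps
    optimizer.zero_grad()
    model(torch.randn(2, 16)).sum().backward()
    optimizer.step()
    lr_scheduler.step()
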
@@ -361,7 +362,7 @@ def setup_model_and_optimizer(model_provider_func, model_type):
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
-        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
+        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)   #XW: load the checkpoint and return the saved iteration count.
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
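
For orientation, a hedged sketch of the load-and-resume pattern behind load_checkpoint; load_ckpt and the single-file layout here are hypothetical and not Megatron's distributed checkpoint format.

import os
import torch

def load_ckpt(path, model, optimizer):
    if not os.path.exists(path):
        return 0                                   # fresh run starts at iteration 0
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["iteration"]                      # resume point for the training loop

# e.g. args.iteration = load_ckpt("checkpoints/latest.pt", model, optimizer)
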
@@ -369,9 +370,12 @@ def setup_model_and_optimizer(model_provider_func, model_type):
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
+    #XW: model and unwrapped_model are lists; model[0] is a BertModel or GPTModel object.

    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'


    # get model without FP16 and/or TorchDDP wrappers
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
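
A generic illustration of the unwrapping referred to in the comment above (model[0] being the underlying BertModel/GPTModel): wrappers such as DistributedDataParallel keep the real module under .module, so peeling wrappers off yields the plain nn.Module. This is a simplified sketch, not Megatron's unwrap_model helper.

from torch.nn.parallel import DistributedDataParallel as torchDDP

def unwrap(module, wrapper_types=(torchDDP,)):
    # Peel off wrapper layers until the underlying model is reached.
    while isinstance(module, wrapper_types):
        module = module.module
    return module

# unwrap(model[0]) would return the BertModel / GPTModel instance.
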
@@ -672,20 +676,22 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
    # Write args to tensorboard
    write_args_to_tensorboard()

-    # Turn on training mode which enables dropout.
+    # XW: Turn on training mode, which enables dropout. Dropout behaves differently in training and evaluation, so it must be disabled (model.eval()) for evaluation.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

-    # Iterations.
+    # Iteration is 0 for a new training run; it is non-zero when resuming from a checkpoint.
    iteration = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:

+        #XW: when rampup_batch_size is None, the number of micro-batches stays constant; otherwise it ramps up over training.
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
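
A tiny demonstration of the dropout point made in the train-mode comment above: the same input is stochastic in training mode and passes through unchanged in eval mode.

import torch

layer = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

layer.train()
print(layer(x))   # random elements zeroed, survivors scaled by 1 / (1 - p)

layer.eval()
print(layer(x))   # identity: dropout is disabled during evaluation
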
@@ -693,6 +699,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                       model,
                       optimizer,
                        lr_scheduler)
        
        print("loss_dict.shape ",loss_dict.shape," loss_dict ",loss_dict,flush=True)
        print("skipped_iter ",skipped_iter,flush=True)
        print("grad_norm.shape ",grad_norm.shape,flush=True)
        print("num_zeros_in_grad ",num_zeros_in_grad,flush=True)

        #print_rank_0('End train_step')
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
@@ -854,6 +866,8 @@ def cyclic_iter(iter):
        for x in iter:
            yield x

+#XW: build the train, valid, and test data iterators.

def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """XXX"""