Commit 3ae12a47 authored by Jared Casper

Merge branch 'finetune-fix' into 'main'

Fix finetuning tasks after T5 pipeline merge.

See merge request ADLR/megatron-lm!343
parents f5345dfa a20445d3
+3 −1
@@ -25,6 +25,7 @@ from megatron import get_timers
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
+from megatron.model import ModelType
 from megatron.training import evaluate_and_print_results
 from megatron.training import setup_model_and_optimizer
 from megatron.training import train_step
@@ -248,6 +249,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
 
 
 def finetune(train_valid_datasets_provider, model_provider,
+             model_type=ModelType.encoder_or_decoder,
              forward_step=_cross_entropy_forward_step,
              end_of_epoch_callback_provider=None,
              task_collate_fn=None):
@@ -277,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider,
 
     # Build model, optimizer and learning rate scheduler.
     timers('model and optimizer').start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
     timers('model and optimizer').stop()
 
     # If pretrained checkpoint is provided and we have not trained for
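For context, the new model_type argument is threaded through finetune() into setup_model_and_optimizer(), whose signature changed in the T5 pipeline merge. Below is a minimal sketch of a task script calling the updated finetune(); the module path, the provider stubs, and the remark about encoder-decoder tasks are assumptions for illustration, not part of this merge request:

    from megatron.model import ModelType
    # Assumed module path: finetune() is defined in the tasks utilities
    # (tasks/finetune_utils.py in Megatron-LM).
    from tasks.finetune_utils import finetune

    def train_valid_datasets_provider():
        """Hypothetical provider returning (train_dataset, valid_dataset)."""
        ...

    def model_provider(pre_process=True, post_process=True):
        """Hypothetical provider building and returning the task's model."""
        ...

    # model_type defaults to ModelType.encoder_or_decoder, so existing
    # encoder-only callers keep working without changes; an encoder-decoder
    # task would pass a different ModelType value explicitly.
    finetune(train_valid_datasets_provider, model_provider,
             model_type=ModelType.encoder_or_decoder)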