tasks/finetune_utils.py  (+3 −1)

@@ -25,6 +25,7 @@ from megatron import get_timers
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
+from megatron.model import ModelType
 from megatron.training import evaluate_and_print_results
 from megatron.training import setup_model_and_optimizer
 from megatron.training import train_step
@@ -248,6 +249,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
 
 
 def finetune(train_valid_datasets_provider, model_provider,
+             model_type=ModelType.encoder_or_decoder,
              forward_step=_cross_entropy_forward_step,
              end_of_epoch_callback_provider=None,
              task_collate_fn=None):
@@ -277,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider,
 
     # Build model, optimizer and learning rate scheduler.
     timers('model and optimizer').start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
     timers('model and optimizer').stop()
 
     # If pretrained checkpoint is provided and we have not trained for
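In short, the patch threads a new model_type argument from finetune() through to setup_model_and_optimizer(), defaulting to ModelType.encoder_or_decoder so existing callers are unaffected. Below is a minimal sketch of a downstream caller after this change; the two provider functions are hypothetical placeholders, not part of the diff:

# Hypothetical caller of the updated finetune() entry point; the provider
# stubs below are illustrative placeholders, not part of this patch.
from megatron.model import ModelType
from tasks.finetune_utils import finetune


def my_train_valid_datasets_provider():
    # Hypothetical: build and return (train_dataset, valid_dataset) for the task.
    raise NotImplementedError


def my_model_provider(pre_process=True, post_process=True):
    # Hypothetical: build and return the task model.
    raise NotImplementedError


# Omitting model_type preserves the previous behavior via the new default;
# it can now be passed explicitly when a task needs a different model type.
finetune(my_train_valid_datasets_provider, my_model_provider,
         model_type=ModelType.encoder_or_decoder)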