Commit e7c7a78f authored by Mohammad Shoeybi

Merge branch 'finetune_fix' into 'main'

Update code used for finetuning to latest API.

See merge request ADLR/megatron-lm!174
parents b4b0d739 b219ff00
+1 −2
@@ -52,8 +52,7 @@ class Classification(MegatronModule):
 
     def forward(self, input_ids, attention_mask, tokentype_ids):
 
-        extended_attention_mask = bert_extended_attention_mask(
-            attention_mask, next(self.language_model.parameters()).dtype)
+        extended_attention_mask = bert_extended_attention_mask(attention_mask)
         position_ids = bert_position_ids(input_ids)
 
         _, pooled_output = self.language_model(input_ids,
+1 −2
@@ -64,8 +64,7 @@ class MultipleChoice(MegatronModule):
         attention_mask = attention_mask.view(-1, attention_mask.size(-1))
         tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
 
-        extended_attention_mask = bert_extended_attention_mask(
-            attention_mask, next(self.language_model.parameters()).dtype)
+        extended_attention_mask = bert_extended_attention_mask(attention_mask)
         position_ids = bert_position_ids(input_ids)
 
         _, pooled_output = self.language_model(input_ids,
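
Both hunks above make the same change: the updated bert_extended_attention_mask helper now derives everything it needs from the mask itself, so callers no longer probe next(self.language_model.parameters()).dtype. A minimal sketch of what a dtype-free helper can look like follows; the body is an assumption for illustration, not the verbatim Megatron-LM implementation.

import torch

def bert_extended_attention_mask(attention_mask):
    # Hypothetical sketch: expand a [batch, seq] padding mask into the
    # [batch, 1, seq, seq] self-attention mask that BERT-style layers expect.
    attention_mask_b1s = attention_mask.unsqueeze(1)              # [b, 1, s]
    attention_mask_bs1 = attention_mask.unsqueeze(2)              # [b, s, 1]
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1  # [b, s, s]
    extended_attention_mask = attention_mask_bss.unsqueeze(1)     # [b, 1, s, s]
    # Returning a boolean "positions to mask out" tensor leaves the fill
    # value and dtype to the attention kernel, which is why the caller no
    # longer passes a parameter dtype.
    return extended_attention_mask < 0.5

# Example: a batch of 2 sequences of length 16, second half padded.
mask = torch.cat([torch.ones(2, 8, dtype=torch.long),
                  torch.zeros(2, 8, dtype=torch.long)], dim=1)
extended = bert_extended_attention_mask(mask)  # shape [2, 1, 16, 16], bool

With this contract the mask stays valid whether the model later runs in fp16 or fp32, and the call sites in Classification and MultipleChoice shrink to one line each.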
+3 −3
@@ -161,7 +161,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             start_iteration = 0
 
             # Train for one step.
-            losses_dict, _ = train_step(forward_step, batch, model,
-                                        optimizer, lr_scheduler)
+            losses_dict, skipped_iter = train_step(forward_step, batch, model,
+                                                   optimizer, lr_scheduler)
             iteration += 1
 
@@ -169,7 +169,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                               optimizer.param_groups[0]['lr'],
                                               iteration, optimizer.loss_scale,
-                                              report_memory_flag)
+                                              report_memory_flag, skipped_iter)
 
             # Autoresume
             if args.adlr_autoresume and \
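
The two hunks above pick up the second value that the latest train_step returns, skipped_iter, and forward it to training_log. As a rough sketch of the contract this implies, assuming the flag marks updates discarded by dynamic loss scaling (the body below is illustrative, not the repository's implementation):

def train_step(forward_step, batch, model, optimizer, lr_scheduler):
    # Hypothetical sketch of the updated return contract: report the loss
    # dictionary plus whether this iteration's weight update was skipped.
    optimizer.zero_grad()
    loss, losses_dict = forward_step(batch, model)  # assumed signature
    loss.backward()
    skipped_iter = 0
    if getattr(optimizer, 'overflow', False):
        # Assumed attribute: a mixed-precision wrapper saw inf/nan
        # gradients, so the update is skipped and only the loss scale
        # is adjusted.
        skipped_iter = 1
    else:
        optimizer.step()
        lr_scheduler.step()
    return losses_dict, skipped_iter

Passing the flag on to training_log lets skipped iterations be counted separately, so logged loss averages are not distorted by steps where no update happened.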