Commit 5f174c07 authored by Mohammad

fp32 fixes

parent 1c1a55da
+3 −2
@@ -102,6 +102,7 @@ class ParallelSelfAttention(MegatronModule):
                  output_layer_init_method, layer_number):
         super(ParallelSelfAttention, self).__init__()
         args = get_args()
+        self.fp16 = args.fp16
 
         self.attention_mask_func = attention_mask_func
         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
@@ -244,7 +245,7 @@ class ParallelSelfAttention(MegatronModule):
             query_layer, key_layer)
 
         # fp32 conversion.
-        if self.attention_softmax_in_fp32:
+        if self.fp16 and self.attention_softmax_in_fp32:
             attention_scores = attention_scores.float()
 
         # Apply attention mask. [b, np, s, s]
@@ -267,7 +268,7 @@ class ParallelSelfAttention(MegatronModule):
         attention_probs = self._get_attention_probs(attention_scores)
 
         # fp16 conversion
-        if self.attention_softmax_in_fp32:
+        if self.fp16 and self.attention_softmax_in_fp32:
             attention_probs = attention_probs.half()
 
         # Context layer. [b, s, hp]
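
Note on the change above: the new self.fp16 guard makes the fp32 softmax round-trip run only when the model actually holds fp16 activations. Under pure fp32 training the old code's .float() was a no-op, but the matching .half() silently downcast fp32 attention probabilities to fp16. A minimal standalone sketch of the guarded pattern (hypothetical function, not the Megatron-LM module itself):

import torch

def attention_softmax(attention_scores, fp16, attention_softmax_in_fp32):
    # Optionally run the softmax in fp32 for numerical stability,
    # but only when the scores actually arrive in fp16.
    if fp16 and attention_softmax_in_fp32:
        attention_scores = attention_scores.float()
    attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
    # Cast back so downstream matmuls stay in fp16; without the fp16
    # guard this .half() would downcast fp32 activations as well.
    if fp16 and attention_softmax_in_fp32:
        attention_probs = attention_probs.half()
    return attention_probs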
+4 −1
@@ -379,9 +379,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         iteration += 1
 
         # Logging.
+        loss_scale = None
+        if args.fp16:
+            loss_scale = optimizer.loss_scale
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
-                                          iteration, optimizer.loss_scale,
+                                          iteration, loss_scale,
                                           report_memory_flag)
 
         # Autoresume
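
The logging change fixes the same class of bug on the training side: only the fp16 optimizer wrapper carries a loss_scale attribute, so reading optimizer.loss_scale unconditionally breaks fp32 runs. A minimal sketch of the failure mode and the guard (hypothetical stand-ins for args.fp16 and the optimizer, not the Megatron-LM training loop):

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # plain fp32 optimizer
fp16 = False  # stand-in for args.fp16

# optimizer.loss_scale    # would raise AttributeError on a plain optimizer
loss_scale = None
if fp16:
    loss_scale = optimizer.loss_scale  # only the fp16 wrapper exposes this
print(loss_scale)  # None is what gets passed to the logger under fp32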