Commit 83671bbf authored by mshoeybi's avatar mshoeybi Committed by Deepak Narayanan
Browse files

Address Deepak's comments

parent 8bed1d63
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -185,7 +185,7 @@ def parse_args(extra_args_provider=None, defaults={},
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16, \
            'residual connection in fp32 only supports in fp16 mode.'
            'residual connection in fp32 only supported when using fp16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
+2 −0
Original line number Diff line number Diff line
@@ -568,8 +568,10 @@ class ParallelTransformer(MegatronModule):

        if mpu.is_pipeline_first_stage():
            # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert for float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()