megatron/arguments.py  +1 −1

@@ -185,7 +185,7 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
     if args.fp32_residual_connection:
         assert args.fp16, \
-            'residual connection in fp32 only supports in fp16 mode.'
+            'residual connection in fp32 only supported when using fp16.'
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
         assert args.checkpoint_activations, \
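For context, the reworded assertion is a startup guard against an invalid flag combination: an fp32 residual connection is only meaningful when the model otherwise runs in fp16. A minimal sketch of the same pattern, assuming a standalone argparse setup (the parser wiring below is illustrative, not Megatron's actual `parse_args`):

```python
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Minimal flag-validation sketch')
    parser.add_argument('--fp16', action='store_true',
                        help='Run the model in fp16 mode.')
    parser.add_argument('--fp32-residual-connection', action='store_true',
                        help='Keep the residual connection in fp32.')
    args = parser.parse_args()

    # Same guard as in megatron/arguments.py: reject the flag combination
    # early, with the message reworded by this change.
    if args.fp32_residual_connection:
        assert args.fp16, \
            'residual connection in fp32 only supported when using fp16.'
    return args

if __name__ == '__main__':
    parse_args()
```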
megatron/model/transformer.py  +2 −0

@@ -568,8 +568,10 @@ class ParallelTransformer(MegatronModule):
         if mpu.is_pipeline_first_stage():
             # Data format change to avoid explicit transposes: [b s h] --> [s b h].
+            # If the input flag for fp32 residual connection is set, convert to float.
             if self.fp32_residual_connection:
                 hidden_states = hidden_states.transpose(0, 1).contiguous().float()
+            # Otherwise, leave it as is.
             else:
                 hidden_states = hidden_states.transpose(0, 1).contiguous()
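The two added comments document the branch they precede: on the first pipeline stage, activations arrive as [b, s, h] and are transposed to [s, b, h]; when `fp32_residual_connection` is set, they are also upcast to float32 so the residual stream accumulates in full precision. A self-contained PyTorch sketch of that conversion (the helper name and shapes are hypothetical, not part of Megatron):

```python
import torch

def prepare_hidden_states(hidden_states: torch.Tensor,
                          fp32_residual_connection: bool) -> torch.Tensor:
    """Reorder [b, s, h] -> [s, b, h], optionally upcasting to fp32.

    Mirrors the first-pipeline-stage branch in ParallelTransformer.
    """
    if fp32_residual_connection:
        # Upcast so the residual connection is computed in full precision
        # even when the rest of the model runs in fp16.
        return hidden_states.transpose(0, 1).contiguous().float()
    # Otherwise keep the input dtype (e.g. fp16) unchanged.
    return hidden_states.transpose(0, 1).contiguous()

# Usage: a fake fp16 activation batch of [batch=2, seq=4, hidden=8].
x = torch.randn(2, 4, 8).half()
y = prepare_hidden_states(x, fp32_residual_connection=True)
assert y.shape == (4, 2, 8) and y.dtype == torch.float32
```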