megatron/arguments.py  +4 −0

@@ -294,6 +294,10 @@ def _add_mixed_precision_args(parser):
                        help='Window over which to raise/lower dynamic scale.')
     group.add_argument('--min-scale', type=float, default=1,
                        help='Minimum loss scale for dynamic loss scale.')
+    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
+                       help='Move the cross entropy unreduced loss calculation '
+                            'for lm head to fp16.')
+
     return parser
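For reference, the new flag is a plain argparse store_true option: it defaults to False and flips to True only when passed on the command line. A minimal standalone sketch of the same pattern (the parser and group below are illustrative stand-ins, not Megatron's actual _add_mixed_precision_args):

import argparse

# Minimal sketch of the new store_true flag in isolation; the parser and
# group names here are hypothetical, not Megatron's real ones.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='mixed precision')
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                   help='Move the cross entropy unreduced loss calculation '
                        'for lm head to fp16.')

args = parser.parse_args(['--fp16-lm-cross-entropy'])
assert args.fp16_lm_cross_entropy  # dashes become underscores on the namespace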
megatron/model/gpt2_model.py  +7 −11

@@ -18,6 +18,7 @@
 import torch

 from megatron import get_args
+from megatron import mpu
 from megatron.module import MegatronModule
 from .language_model import parallel_lm_logits

@@ -25,9 +26,6 @@
 from .language_model import get_language_model
 from .utils import init_method_normal
 from .utils import scaled_init_method_normal
-from megatron.utils import report_memory
-from megatron import mpu
-

 def gpt2_attention_mask_func(attention_scores, ltor_mask):
     attention_scores.masked_fill_(ltor_mask, -10000.0)

@@ -51,7 +49,7 @@ class GPT2Model(MegatronModule):
                 scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                              args.num_layers))

-    def forward(self, input_ids, position_ids, attention_mask, labels,
+    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):

@@ -78,14 +76,12 @@ class GPT2Model(MegatronModule):
         if get_key_value:
             output = [output, presents]

-        #report_memory('AAA')
-        losses = mpu.vocab_parallel_cross_entropy(output, labels)
-        #report_memory('BBB')
-
-        #return output
-        return losses
+        if labels is None:
+            return output
+        else:
+            loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            return loss

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
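The signature change makes labels optional, so the same forward serves two callers: without labels it returns raw logits (letting the caller compute the loss externally, e.g. in fp32), and with labels it computes the unreduced cross entropy inside the model, in the logits' own dtype, which is fp16 under mixed precision. A minimal sketch of that dispatch pattern, using a toy linear head and torch.nn.functional.cross_entropy as a stand-in for mpu.vocab_parallel_cross_entropy:

import torch
import torch.nn.functional as F

class ToyLMHead(torch.nn.Module):
    """Illustrative stand-in for GPT2Model's labels=None dispatch."""

    def __init__(self, hidden, vocab):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, vocab)

    def forward(self, hidden_states, labels=None):
        output = self.proj(hidden_states)  # [batch, seq, vocab] logits
        if labels is None:
            # Inference path: hand back raw logits.
            return output
        # Training path: unreduced per-token loss, computed in the logits'
        # dtype (stand-in for mpu.vocab_parallel_cross_entropy).
        loss = F.cross_entropy(output.view(-1, output.size(-1)),
                               labels.view(-1), reduction='none')
        return loss.view_as(labels)

model = ToyLMHead(hidden=8, vocab=16)
x = torch.randn(2, 4, 8)
labels = torch.randint(16, (2, 4))
print(model(x).shape)          # torch.Size([2, 4, 16]) -- logits
print(model(x, labels).shape)  # torch.Size([2, 4])     -- per-token loss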
megatron/training.py  +0 −1

@@ -379,7 +379,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                      optimizer.param_groups[0]['lr'],
                      iteration, loss_scale,
                      report_memory_flag)
-        #report_memory_flag = True

         # Autoresume
         if args.adlr_autoresume and \
pretrain_gpt2.py  +9 −7

@@ -27,7 +27,7 @@
 from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import report_memory
+from megatron.utils import reduce_losses

 def model_provider():
     """Build the model."""

@@ -72,6 +72,7 @@ def get_batch(data_iterator):

 def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
     timers = get_timers()

     # Get the batch.

@@ -81,12 +82,13 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()

     # Forward model.
-    losses = model(tokens, position_ids, attention_mask, labels)
-    #report_memory('CCC')
-    #exit()
-    #losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
-    #                                          labels)
-    #report_memory('DDD')
+    if args.fp16_lm_cross_entropy:
+        losses = model(tokens, position_ids, attention_mask, labels=labels)
+    else:
+        output = model(tokens, position_ids, attention_mask)
+        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
+                                                  labels)
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
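Whichever branch runs, the unreduced per-token losses go through the same masked average, so padding tokens contribute nothing to the training loss. A self-contained sketch of that final reduction (the shapes and mask values below are made up for illustration):

import torch

# Unreduced per-token losses and a 0/1 mask over real (non-padding) tokens.
losses = torch.rand(2, 4)                    # stand-in for the model output
loss_mask = torch.tensor([[1., 1., 1., 0.],
                          [1., 1., 0., 0.]])

loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
print(loss)  # mean loss over the 5 unmasked tokens only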