This change adds a --fp16-lm-cross-entropy option and moves the unreduced LM cross entropy calculation into the BERT and GPT-2 model forward passes, so it can optionally run in fp16 instead of always upcasting the logits to fp32.

megatron/arguments.py  (+7 −0)

@@ -97,6 +97,9 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.num_unique_layers < args.num_layers:
         assert args.DDP_impl == 'local', \
             'torch-DDP does not work with parameters sharing.'
+    # Mixed precision checks.
+    if args.fp16_lm_cross_entropy:
+        assert args.fp16, 'lm cross entropy in fp16 is only supported in fp16 mode.'
     _print_args(args)
     return args

@@ -294,6 +297,10 @@ def _add_mixed_precision_args(parser):
                        help='Window over which to raise/lower dynamic scale.')
     group.add_argument('--min-scale', type=float, default=1,
                        help='Minimum loss scale for dynamic loss scale.')
+    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
+                       help='Move the cross entropy unreduced loss calculation '
+                            'for lm head to fp16.')

     return parser
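For context, here is a minimal standalone sketch of how the new flag and the mixed-precision check fit together. This is not the Megatron parser itself (which builds many more argument groups), and the --fp16 help text is paraphrased:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='mixed precision')
    group.add_argument('--fp16', action='store_true',
                       help='Run the model in fp16 mode.')
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation '
                            'for lm head to fp16.')

    # Parse a hypothetical command line; dropping '--fp16' here would trip the assert.
    args = parser.parse_args(['--fp16', '--fp16-lm-cross-entropy'])

    # Consistency check mirrored from parse_args().
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 is only supported in fp16 mode.'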
megatron/model/bert_model.py  (+15 −2)

@@ -18,6 +18,7 @@
 import torch

 from megatron import get_args
+from megatron import mpu
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm

@@ -114,6 +115,7 @@ class BertModel(MegatronModule):
         super(BertModel, self).__init__()
         args = get_args()

+        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
         init_method = init_method_normal(args.init_method_std)

@@ -138,7 +140,8 @@ class BertModel(MegatronModule):
                 init_method)
             self._binary_head_key = 'binary_head'

-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
+    def forward(self, input_ids, attention_mask, tokentype_ids=None,
+                lm_labels=None):

         extended_attention_mask = bert_extended_attention_mask(
             attention_mask, next(self.language_model.parameters()).dtype)

@@ -161,11 +164,21 @@ class BertModel(MegatronModule):
         lm_logits = self.lm_head(
             lm_output, self.language_model.embedding.word_embeddings.weight)

+        binary_logits = None
         if self.add_binary_head:
             binary_logits = self.binary_head(pooled_output)
-            return lm_logits, binary_logits

-        return lm_logits, None
+        if lm_labels is None:
+            return lm_logits, binary_logits
+        else:
+            if self.fp16_lm_cross_entropy:
+                assert lm_logits.dtype == torch.half
+                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            else:
+                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                           lm_labels)
+            return lm_loss, binary_logits

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
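A toy sketch of the new forward contract, with torch.nn.functional.cross_entropy standing in for mpu.vocab_parallel_cross_entropy (which needs an initialized model-parallel group, so this is illustrative only): omitting lm_labels still yields logits, passing them yields the unreduced per-token LM loss, optionally computed directly on fp16 logits.

    import torch
    import torch.nn.functional as F

    def toy_lm_head_loss(lm_logits, lm_labels=None, fp16_lm_cross_entropy=False):
        # lm_logits: [batch, seq, vocab]; lm_labels: [batch, seq] or None.
        if lm_labels is None:
            return lm_logits                      # old behaviour: raw logits
        if fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half  # guard mirrored from the diff
            logits = lm_logits
        else:
            logits = lm_logits.float()            # upcast before the loss
        # Unreduced per-token loss, analogous to mpu.vocab_parallel_cross_entropy.
        return F.cross_entropy(logits.view(-1, logits.size(-1)),
                               lm_labels.view(-1),
                               reduction='none').view_as(lm_labels)

    lm_logits = torch.randn(2, 5, 100)            # fp32 toy logits
    lm_labels = torch.randint(0, 100, (2, 5))
    print(toy_lm_head_loss(lm_logits, lm_labels).shape)  # torch.Size([2, 5])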
megatron/model/gpt2_model.py  (+13 −2)

@@ -18,6 +18,7 @@
 import torch

 from megatron import get_args
+from megatron import mpu
 from megatron.module import MegatronModule
 from .language_model import parallel_lm_logits

@@ -39,6 +40,7 @@ class GPT2Model(MegatronModule):
         args = get_args()

         self.parallel_output = parallel_output
+        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

         self.language_model, self._language_model_key = get_language_model(
             attention_mask_func=gpt2_attention_mask_func,

@@ -48,7 +50,7 @@ class GPT2Model(MegatronModule):
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))

-    def forward(self, input_ids, position_ids, attention_mask,
+    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):

@@ -75,7 +77,16 @@ class GPT2Model(MegatronModule):
         if get_key_value:
             output = [output, presents]

-        return output
+        if labels is None:
+            return output
+        else:
+            if self.fp16_lm_cross_entropy:
+                assert output.dtype == torch.half
+                loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            else:
+                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+            return loss

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
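The practical benefit, presumably, is skipping the output.float() upcast in the default path, which materializes a full fp32 copy of the logits. A back-of-the-envelope sketch with an assumed GPT-2-style activation shape (per model-parallel partition; the numbers are illustrative only):

    # Size of the fp16 logits tensor versus the fp32 copy created by .float().
    # The shape is an assumption for illustration: [batch=8, seq=1024, vocab=50304].
    batch, seq, vocab = 8, 1024, 50304
    elems = batch * seq * vocab
    print(f'fp16 logits     : {elems * 2 / 2**20:7.1f} MiB')  # ~786 MiB
    print(f'fp32 upcast copy: {elems * 4 / 2**20:7.1f} MiB')  # ~1572 MiB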
pretrain_bert.py  (+7 −6)

@@ -68,6 +68,7 @@ def get_batch(data_iterator):

 def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
     timers = get_timers()

     # Get the batch.

@@ -76,15 +77,15 @@ def forward_step(data_iterator, model):
         = get_batch(data_iterator)
     timers('batch generator').stop()

-    # Forward model.
-    lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
+    # Forward model. lm_labels
+    lm_loss_, sop_logits = model(tokens, padding_mask,
+                                 tokentype_ids=types, lm_labels=lm_labels)

-    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
-                               sentence_order.view(-1).contiguous(),
+    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
+                               sentence_order.view(-1),
                                ignore_index=-1)

-    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
-                                                lm_labels.contiguous())
-
     lm_loss = torch.sum(
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
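A self-contained toy version of the loss assembly forward_step now performs with the per-token lm_loss_ returned by the model (toy tensors throughout; entries of -1 in sentence_order are skipped via ignore_index):

    import torch
    import torch.nn.functional as F

    # Stand-ins for what model(..., lm_labels=lm_labels) and get_batch() provide.
    lm_loss_ = torch.tensor([[2.0, 1.0, 3.0], [0.5, 4.0, 1.5]])   # per-token LM loss [b, s]
    loss_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])  # 1 = token counts toward loss
    sop_logits = torch.randn(2, 2)                                # binary (SOP) head output
    sentence_order = torch.tensor([0, -1])                        # -1 = ignore this pair

    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                               sentence_order.view(-1), ignore_index=-1)
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
    print(lm_loss)  # (2.0 + 1.0 + 0.5) / 3 ≈ 1.1667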
pretrain_gpt2.py  (+3 −3)

@@ -72,6 +72,7 @@ def get_batch(data_iterator):

 def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
     timers = get_timers()

     # Get the batch.

@@ -81,9 +82,8 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()

     # Forward model.
-    output = model(tokens, position_ids, attention_mask)
-    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
-                                              labels)
+    losses = model(tokens, position_ids, attention_mask, labels=labels)
+
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()