megatron/arguments.py +3 −0

@@ -97,6 +97,9 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.num_unique_layers < args.num_layers:
         assert args.DDP_impl == 'local', \
             'torch-DDP does not work with parameters sharing.'
+    # Mixed precision checks.
+    if args.fp16_lm_cross_entropy:
+        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
 
     _print_args(args)
     return args
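For reference, a minimal standalone sketch of the validation this hunk adds, using a hypothetical argparse setup with only the two flags involved (Megatron builds its full parser elsewhere in parse_args):

import argparse

# Hypothetical parser holding just the two flags the new check touches.
parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--fp16-lm-cross-entropy', action='store_true')
args = parser.parse_args(['--fp16', '--fp16-lm-cross-entropy'])

# Mixed precision checks (same logic as the added lines): the fp16 cross
# entropy path only makes sense when the model itself runs in fp16.
if args.fp16_lm_cross_entropy:
    assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'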
megatron/model/bert_model.py +7 −1

@@ -115,6 +115,7 @@ class BertModel(MegatronModule):
         super(BertModel, self).__init__()
         args = get_args()
 
+        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
         init_method = init_method_normal(args.init_method_std)

@@ -170,7 +171,12 @@ class BertModel(MegatronModule):
         if lm_labels is None:
             return lm_logits, binary_logits
         else:
-            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(), lm_labels)
+            if self.fp16_lm_cross_entropy:
+                assert lm_logits.dtype == torch.half
+                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            else:
+                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                           lm_labels)
             return lm_loss, binary_logits
megatron/model/gpt2_model.py +6 −1

@@ -40,6 +40,7 @@ class GPT2Model(MegatronModule):
         args = get_args()
 
         self.parallel_output = parallel_output
+        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
 
         self.language_model, self._language_model_key = get_language_model(
             attention_mask_func=gpt2_attention_mask_func,

@@ -79,7 +80,11 @@ class GPT2Model(MegatronModule):
         if labels is None:
             return output
         else:
-            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+            if self.fp16_lm_cross_entropy:
+                assert output.dtype == torch.half
+                loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            else:
+                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
             return loss
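A minimal sketch of the dtype-dependent loss path both model hunks introduce. Here torch.nn.functional.cross_entropy stands in for mpu.vocab_parallel_cross_entropy (the real call needs an initialized model-parallel group), and the helper name lm_loss_sketch is made up for illustration:

import torch
import torch.nn.functional as F

def lm_loss_sketch(logits, labels, fp16_lm_cross_entropy):
    # logits: [batch, seq, vocab]; labels: [batch, seq] of token ids.
    if fp16_lm_cross_entropy:
        # Keep the logits in half precision for the loss computation.
        assert logits.dtype == torch.half
        flat_logits = logits.view(-1, logits.size(-1))
    else:
        # Default path: upcast to fp32 before the loss.
        flat_logits = logits.float().view(-1, logits.size(-1))
    # reduction='none' mirrors vocab_parallel_cross_entropy returning
    # per-token losses; masking and averaging happen in the pretrain script.
    return F.cross_entropy(flat_logits, labels.view(-1), reduction='none')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.half if device == 'cuda' else torch.float
logits = torch.randn(2, 8, 32, device=device, dtype=dtype)
labels = torch.randint(0, 32, (2, 8), device=device)
losses = lm_loss_sketch(logits, labels, fp16_lm_cross_entropy=(dtype == torch.half))
print(losses.shape)  # torch.Size([16])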
pretrain_bert.py +5 −9

@@ -78,16 +78,12 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
 
     # Forward model. lm_labels
-    if args.fp16_lm_cross_entropy:
-        lm_loss_, sop_logits = model(tokens, padding_mask, tokentype_ids=types,
-                                     lm_labels=lm_labels)
-    else:
-        lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
-        lm_loss_ = mpu.vocab_parallel_cross_entropy(
-            lm_logits.contiguous().float(), lm_labels.contiguous())
+    lm_loss_, sop_logits = model(tokens, padding_mask, tokentype_ids=types,
+                                 lm_labels=lm_labels)
 
-    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
-                               sentence_order.view(-1).contiguous(),
+    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
+                               sentence_order.view(-1),
                                ignore_index=-1)
 
     lm_loss = torch.sum(
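The sentence-order (SOP) loss in this hunk only loses its .contiguous() calls; a dummy-tensor sketch of the same call, where sentence_order holds the binary targets and -1 marks positions to ignore:

import torch
import torch.nn.functional as F

# Dummy SOP logits for a batch of 4 sequences and their binary targets;
# entries equal to -1 are excluded from the loss.
sop_logits = torch.randn(4, 2)
sentence_order = torch.tensor([0, 1, 1, -1])

sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                           sentence_order.view(-1),
                           ignore_index=-1)
print(sop_loss)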
pretrain_gpt2.py +2 −7

@@ -82,12 +82,7 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
 
     # Forward model.
-    if args.fp16_lm_cross_entropy:
-        losses = model(tokens, position_ids, attention_mask, labels=labels)
-    else:
-        output = model(tokens, position_ids, attention_mask)
-        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
-                                                  labels)
+    losses = model(tokens, position_ids, attention_mask, labels=labels)
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
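The masked averaging after the now-unconditional model call is unchanged by this hunk; isolated with dummy tensors, it is simply a mean of the per-token losses over non-padded positions:

import torch

# Dummy per-token LM losses and a mask with the trailing (padded) positions zeroed.
losses = torch.rand(4, 16)
loss_mask = torch.ones(4, 16)
loss_mask[:, 12:] = 0.0

loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
print(loss)  # average loss over unmasked tokens only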