Commit 2ede8235 authored by mohammad

testing — moves the vocab-parallel cross-entropy computation from pretrain_gpt2.py's forward_step into GPT2Model.forward, with commented-out report_memory probes left in for memory debugging.

parent 5897a790
+12 −2
@@ -25,6 +25,9 @@ from .language_model import get_language_model
 from .utils import init_method_normal
 from .utils import scaled_init_method_normal
 
+from megatron.utils import report_memory
+from megatron import mpu
+
 
 def gpt2_attention_mask_func(attention_scores, ltor_mask):
     attention_scores.masked_fill_(ltor_mask, -10000.0)
@@ -48,7 +51,7 @@ class GPT2Model(MegatronModule):
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
 
-    def forward(self, input_ids, position_ids, attention_mask,
+    def forward(self, input_ids, position_ids, attention_mask, labels,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):
 
@@ -75,7 +78,14 @@ class GPT2Model(MegatronModule):
         if get_key_value:
             output = [output, presents]
 
-        return output
+        #report_memory('AAA')
+
+        losses = mpu.vocab_parallel_cross_entropy(output, labels)
+
+        #report_memory('BBB')
+
+        #return output
+        return losses
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
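
The net effect in this file: GPT2Model.forward now takes labels and returns per-token losses directly, so the trainer never materializes the full [batch, seq, vocab] float logits tensor on its side. On a single tensor-parallel rank, mpu.vocab_parallel_cross_entropy reduces to ordinary unreduced cross entropy; a runnable stand-in showing the contract the new forward relies on (shapes are hypothetical):

    import torch
    import torch.nn.functional as F

    # Single-rank equivalent of mpu.vocab_parallel_cross_entropy: with no
    # vocabulary sharding it is just per-token cross entropy, unreduced.
    batch, seq, vocab = 2, 4, 11
    logits = torch.randn(batch, seq, vocab)
    labels = torch.randint(0, vocab, (batch, seq))
    losses = F.cross_entropy(logits.view(-1, vocab), labels.view(-1),
                             reduction='none').view(batch, seq)
    print(losses.shape)  # torch.Size([2, 4]): one loss per token, the
                         # shape GPT2Model.forward now returns

One behavioral note: the old call site in pretrain_gpt2.py cast the logits with .contiguous().float() before the cross entropy; the in-model version applies the loss to output as-is, so under fp16 training the loss would now be computed from half-precision logits unless a cast happens elsewhere.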
+1 −0
@@ -379,6 +379,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
                                           report_memory_flag)
+        #report_memory_flag = True
 
         # Autoresume
         if args.adlr_autoresume and \
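
For context on the commented-out flag: in this version of Megatron's train() loop, training_log() reports GPU memory once and then hands back a cleared report_memory_flag, so forcing it back to True here would re-arm a memory report on every iteration, in the same spirit as the report_memory('AAA')/('BBB')/('CCC')/('DDD') probes added in the other files. Roughly what such a probe prints (a sketch of megatron.utils.report_memory's behavior, not the verbatim helper; requires a CUDA build):

    import torch

    def report_memory(name):
        # Print the current and peak allocator counters for this rank.
        mb = 1024 * 1024
        print('{} | allocated: {:.1f} MiB | max allocated: {:.1f} MiB'.format(
            name,
            torch.cuda.memory_allocated() / mb,
            torch.cuda.max_memory_allocated() / mb))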
+7 −4
@@ -27,7 +27,7 @@ from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import reduce_losses
-
+from megatron.utils import report_memory
 
 def model_provider():
     """Build the model."""
@@ -81,9 +81,12 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
 
     # Forward model.
-    output = model(tokens, position_ids, attention_mask)
-    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
-                                              labels)
+    losses = model(tokens, position_ids, attention_mask, labels)
+    #report_memory('CCC')
+    #exit()
+    #losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
+    #                                          labels)
+    #report_memory('DDD')
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
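
The masked averaging kept at the end of forward_step is what lets forward return raw per-token losses: positions excluded by the mask (e.g. padding) are zeroed out and the sum is normalized by the number of tokens that count. A toy run of those two lines (hypothetical values):

    import torch

    losses = torch.tensor([[0.5, 1.0, 1.5, 9.0],
                           [2.0, 3.0, 9.0, 9.0]])   # per-token losses
    loss_mask = torch.tensor([[1., 1., 1., 0.],
                              [1., 1., 0., 0.]])    # 0 = don't count
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    print(loss)  # tensor(1.6000): mean over the 5 unmasked tokens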