Commit b938ec51 authored by Jared Casper's avatar Jared Casper
Browse files

Tasks seem to be working.

parent 3b91262e
Loading
Loading
Loading
Loading
+2 −3
Original line number Diff line number Diff line
@@ -58,7 +58,7 @@ class MultipleChoice(MegatronModule):
                                                     init_method)
            self._multichoice_head_key = 'multichoice_head'

    def set_input_tensor(self, input_tensor)
    def set_input_tensor(self, input_tensor):
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, model_input, attention_mask, tokentype_ids=None):
@@ -127,4 +127,3 @@ class MultipleChoice(MegatronModule):
                print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                'initializing to random'.format(
                                    self._multichoice_head_key))
+13 −0
Original line number Diff line number Diff line
@@ -24,6 +24,18 @@ from megatron import mpu
from megatron import p2p_communication


def get_forward_backward_func():
    """Pick the forward-backward schedule matching the parallel config.

    Returns the no-pipelining schedule when there is a single pipeline
    stage; otherwise chooses between the interleaved and non-interleaved
    pipelining schedules based on the virtual-pipeline setting.
    """
    args = get_args()
    # Single pipeline stage: no inter-stage communication needed.
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        return forward_backward_no_pipelining
    # Interleaved schedule requires a virtual pipeline size to be set.
    if args.virtual_pipeline_model_parallel_size is not None:
        return forward_backward_pipelining_with_interleaving
    return forward_backward_pipelining_without_interleaving


def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Forward step for passed-in model.

@@ -34,6 +46,7 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
    timers = get_timers()

    timers('forward-compute').start()
    # TODO
    model.module.module.set_input_tensor(input_tensor)
    output_tensor, loss_func = forward_step_func(data_iterator, model)
    if mpu.is_pipeline_last_stage():
+11 −41
Original line number Diff line number Diff line
@@ -26,9 +26,8 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.training import communicate
from megatron.utils import get_ltor_masks_and_position_ids

from megatron.p2p_communication import recv_forward, send_forward

def get_batch(context_tokens):
    """Generate batch from context tokens."""
@@ -395,55 +394,26 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):

    # Hidden size changes when not using recompute, need to tell communicate()
    # the correct size
    # Hidden size changes when not using recompute, need to tell p2p_communicate
    # functions the correct size
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]

    if not mpu.is_pipeline_first_stage():
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=True,
            recv_backward=False)
    else:
        input_tensor = None
    input_tensor = recv_forward()

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
    model.set_input_tensor(input_tensor)
    output_tensor = model(tokens, position_ids, attention_mask,
                          tokentype_ids=tokentype_ids,
                          layer_past=layer_past,
                          get_key_value=get_key_value,
                          forward_method_parallel_output=forward_method_parallel_output)
        else:
            output_tensor = model(tokens, position_ids, attention_mask,
                                  tokentype_ids=tokentype_ids,
                                  layer_past=layer_past,
                                  get_key_value=get_key_value)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask,
                              layer_past=layer_past,
                              get_key_value=get_key_value,
                              forward_method_parallel_output=forward_method_parallel_output)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask,
                              layer_past=layer_past,
                              get_key_value=get_key_value)

    if get_key_value:
        output_tensor, layer_past = output_tensor

    if not mpu.is_pipeline_last_stage():
        communicate(tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)
    send_forward(output_tensor)

    args.seq_length = orig_seq_length
    if get_key_value:
+69 −51
Original line number Diff line number Diff line
@@ -17,13 +17,14 @@

import os
import time
from functools import partial

import torch

from megatron import get_args
from megatron import print_rank_last, is_last_rank
from megatron import mpu
from megatron.training import communicate
from megatron.schedules import get_forward_backward_func
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch

@@ -38,7 +39,7 @@ def accuracy_func_provider(single_dataset_provider):
    for datapath in datapaths:
        dataset = single_dataset_provider(datapath)
        dataloader = build_data_loader(
            dataset, args.micro_batch_size, num_workers=args.num_workers,
            dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))

@@ -73,14 +74,61 @@ def accuracy_func_provider(single_dataset_provider):

    return metrics_func


def calculate_correct_answers(name, model, dataloader,
                              epoch, output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true."""
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()
    model.eval()
    saved_batch_size = args.micro_batch_size
    for m in model:
        m.eval()
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size

    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        sample_multiplier = ds.sample_multiplier
    else:
        sample_multiplier = 1
    micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
    num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel

    def loss_func(output_predictions, labels, output_tensor):
        """Turn model logits into correct/total counts for accuracy.

        Returns a dummy loss of 0 plus a dict with 'total' and 'correct'
        counters (and, when output_predictions is set, raw predictions).
        """
        logits = output_tensor

        loss_dict = {}
        # Add output predictions.
        if output_predictions:
            # NOTE(review): 'assert False' makes the rest of this branch
            # unreachable; 'batch' is not a parameter of this function and
            # presumably relies on the enclosing scope -- confirm before
            # enabling this path.
            assert False
            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
                logits.float()).data.cpu().numpy().tolist()
            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
            loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels)
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        return 0, loss_dict

    # defined inside to capture output_predictions
    def correct_answers_forward_step(batch, model):
        """Forward step handed to the schedule; returns logits + loss closure."""
        # 'batch' may be an iterator or an already-materialized batch;
        # fall back to using it directly when next() fails.
        # NOTE(review): except BaseException is very broad -- it also
        # swallows KeyboardInterrupt; confirm this is intentional.
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)

        # Forward model.
        args = get_args()  # NOTE(review): 'args' appears unused in this function
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)

        # Defer loss computation so the schedule runs it on the last stage.
        return output_tensor, partial(loss_func, output_predictions, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
@@ -92,60 +140,30 @@ def calculate_correct_answers(name, model, dataloader,
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # Run the model forward.
            tokens, types, labels_, attention_mask = process_batch(batch)

            # For evaluation only mode we use drop_last = False to get all the
            # samples, which means we might not have a full batch, so we
            # adjust batch_size here to actual batch size of data
            actual_batch_size = len(labels_)
            actual_batch_size = len(batch['label'])
            # ... applying sample_multiplier if necessary
            ds = dataloader.dataset
            if hasattr(ds, 'sample_multiplier'):
                actual_batch_size *= ds.sample_multiplier
            args.micro_batch_size = actual_batch_size

            if not mpu.is_pipeline_first_stage():
                input_tensor, _ = communicate(
                    tensor_send_next=None,
                    tensor_send_prev=None,
                    recv_forward=True,
                    recv_backward=False)
            else:
                input_tensor = None
            args.micro_batch_size = actual_batch_size * sample_multiplier
            args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches

            # Forward model.
            if mpu.is_pipeline_first_stage():
                assert input_tensor is None
                output_tensor = model(tokens, attention_mask, tokentype_ids=types)
            else:
                assert input_tensor is not None
                output_tensor = model(input_tensor, attention_mask)
            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
                                               optimizer=None, timers=None, forward_only=True)

            if mpu.is_pipeline_last_stage():
                logits = output_tensor

                # Add output predictions.
            for loss_dict in loss_dicts:
                if output_predictions:
                    softmaxes.extend(torch.nn.Softmax(dim=-1)(
                        logits.float()).data.cpu().numpy().tolist())
                    labels.extend(labels_.data.cpu().numpy().tolist())
                    ids.extend(batch['uid'].cpu().numpy().tolist())
                # Compute the correct answers.
                predicted = torch.argmax(logits, dim=-1)
                corrects = (predicted == labels_)
                # Add to the counters.
                total += labels_.size(0)
                correct += corrects.sum().item()
            else:
                communicate(
                    tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)

    model.train()
    args.micro_batch_size = saved_batch_size
                    softmaxes.extend(loss_dict['softmaxes'])
                    labels.extend(loss_dict['labels'])
                    ids.extend(loss_dict['ids'])
                total += loss_dict['total']
                correct += loss_dict['correct']


    for m in model:
        m.train()
    args.micro_batch_size = saved_micro_batch_size
    args.global_batch_size = saved_global_batch_size

    # Reduce.
    if mpu.is_pipeline_last_stage():
+25 −25
Original line number Diff line number Diff line
@@ -15,6 +15,8 @@

"""Finetune utilities."""

from functools import partial

import torch

from megatron import get_args
@@ -46,7 +48,20 @@ def process_batch(batch):
    return tokens, types, labels, attention_mask


def _cross_entropy_forward_step(batch, model, input_tensor):
def cross_entropy_loss_func(labels, output_tensor):
    """Compute cross-entropy loss over the model's output logits.

    Returns the loss tensor together with a dict carrying the
    data-parallel-averaged value for logging.
    """
    criterion = torch.nn.CrossEntropyLoss()
    lm_loss = criterion(output_tensor.contiguous().float(), labels)

    # Average across the data-parallel group so every rank logs the same value.
    averaged = average_losses_across_data_parallel_group([lm_loss])

    return lm_loss, {'lm loss': averaged[0]}


def _cross_entropy_forward_step(batch, model):
    """Simple forward step with cross-entropy loss."""
    timers = get_timers()

@@ -60,25 +75,9 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
    timers('batch-generator').stop()

    # Forward model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
    output_tensor = model(tokens, attention_mask, tokentype_ids=types)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask)

    if mpu.is_pipeline_last_stage():
        logits = output_tensor

        # Cross-entropy loss.
        loss_func = torch.nn.CrossEntropyLoss()
        loss = loss_func(logits.contiguous().float(), labels)

        # Reduce loss for logging.
        averaged_loss = average_losses_across_data_parallel_group([loss])

        return loss, {'lm loss': averaged_loss[0]}
    return output_tensor
    return output_tensor, partial(cross_entropy_loss_func, labels)


def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
@@ -135,6 +134,8 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
    # This is necessary so pipeline transfers know what size they are
    # and the LR schedule, which is based on samples seen, gets set
    # correctly.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size
    if hasattr(train_dataset, 'sample_multiplier'):
        args.micro_batch_size *= train_dataset.sample_multiplier
        args.global_batch_size *= train_dataset.sample_multiplier
@@ -149,7 +150,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
    timers = get_timers()

    # Turn on training mode which enables dropout.
    model.train()
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}
@@ -180,10 +182,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
            start_iteration = 0

            # Train for one step.
            losses_dict, skipped_iter, grad_norm = train_step(forward_step,
                                                              batch, model,
                                                              optimizer,
                                                              lr_scheduler)
            out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
            iteration += 1

            # Logging.
@@ -195,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                              iteration,
                                              optimizer.get_loss_scale().item(),
                                              report_memory_flag, skipped_iter,
                                              grad_norm, params_norm)
                                              grad_norm, params_norm, num_zeros_in_grad)

            # Autoresume
            if args.adlr_autoresume and \
Loading