Commit 01fc0833 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'vit_pipeline_fixes' into 'main'

vit pipeline fixes

See merge request ADLR/megatron-lm!276
parents 217f54b3 ccae9dbd
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -60,8 +60,8 @@ def check_checkpoint_args(checkpoint_args):
    _compare('num_layers')
    _compare('hidden_size')
    _compare('num_attention_heads')
    _compare('max_position_embeddings')
    if args.vocab_file:
        _compare('max_position_embeddings')
        _compare('make_vocab_size_divisible_by')
        _compare('padded_vocab_size')
        _compare('tokenizer_type')
+76 −52
Original line number Diff line number Diff line
@@ -50,11 +50,11 @@ class VitMlpHead(MegatronModule):
    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [b, s, h]
        # sequence_index: index of the token to pool.
        x = hidden_states[:, sequence_index, :]
        x = self.dense_in(x)
        x = torch.tanh(x)
        x = self.dense_out(x)
        return x
        hidden_state = hidden_states[:, sequence_index, :]
        dense_in_result = self.dense_in(hidden_state)
        tanh_result = torch.tanh(dense_in_result)
        dense_out_result = self.dense_out(tanh_result)
        return dense_out_result


def twod_interpolate_position_embeddings_hook(
@@ -122,8 +122,12 @@ def twod_interpolate_position_embeddings_hook(
class VitModel(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self, num_classes, finetune=False):
        super(VitModel, self).__init__()
    def __init__(self, 
                 num_classes,
                 finetune=False,
                 pre_process=True,
                 post_process=True):
        super(VitModel, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
@@ -136,6 +140,8 @@ class VitModel(MegatronModule):
                args.init_method_std, args.num_layers
            )

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.patch_dim = args.patch_dim
@@ -148,8 +154,11 @@ class VitModel(MegatronModule):
        self.seq_length = self.num_patches + 1
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels

        if self.pre_process:
            # cls_token
        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
            self.cls_token = torch.nn.Parameter(
                torch.randn(1, 1, self.hidden_size)
            )
            torch.nn.init.zeros_(self.cls_token)

            # Linear encoder
@@ -174,9 +183,13 @@ class VitModel(MegatronModule):

        # Transformer
        self.transformer = ParallelTransformer(
            self.init_method, self.scaled_init_method
            self.init_method, 
            self.scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process
        )

        if self.post_process:
            # MLP head
            if not self.finetune:
                self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
@@ -185,26 +198,37 @@ class VitModel(MegatronModule):
                    self.hidden_size, num_classes, torch.nn.init.zeros_
                )

    def forward(self, x):
        x = einops.rearrange(
            x,
    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.transformer.set_input_tensor(input_tensor)

    def forward(self, input):

        if self.pre_process:
            rearranged_input = einops.rearrange(
                input,
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=self.patch_dim,
                p2=self.patch_dim,
            )

        assert x.dtype == torch.half
        x = self.linear_encoder(x)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
            assert rearranged_input.dtype == torch.half
            encoder_output = self.linear_encoder(rearranged_input)
            cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
            concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)

            token_embeddings = concatenated_tokens + \
                self.position_embeddings(self.position_ids)
            hidden_states = self.embedding_dropout(token_embeddings)
        else:
            hidden_states = input

        x = x + self.position_embeddings(self.position_ids)
        x = self.embedding_dropout(x)
        x = self.transformer(x, None)
        hidden_states = self.transformer(hidden_states, None)

        if self.post_process:
            if not self.finetune:
            x = self.mlp_head(x)
                hidden_states = self.mlp_head(hidden_states)
            else:
            x = self.class_head(x[:, 0, :])
                hidden_states = self.class_head(hidden_states[:, 0, :])

        return x
        return hidden_states
+20 −14
Original line number Diff line number Diff line
@@ -17,19 +17,22 @@

import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vit_model import VitModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group

def model_provider():
def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0("building VIT model ...")
    args = get_args()

    model = VitModel(num_classes=args.num_classes)
    model = VitModel(num_classes=args.num_classes,
                     pre_process=pre_process,
                     post_process=post_process)
    return model

def get_batch(data_iterator):
@@ -42,10 +45,21 @@ def get_batch(data_iterator):

    return images, labels

def forward_step(data_iterator, model, input_tensor):
def loss_func(labels, output_tensor):
    logits = output_tensor.contiguous().float()
    loss = F.cross_entropy(logits, labels)

    outputs = torch.argmax(logits, -1)
    correct = (outputs == labels).float()
    accuracy = torch.mean(correct)

    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])

    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}

def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()
    assert input_tensor is None

    # Get the batch.
    timers("batch-generator").start()
@@ -56,17 +70,9 @@ def forward_step(data_iterator, model, input_tensor):
    timers("batch-generator").stop()

    # Forward model. lm_labels
    logits = model(images).contiguous().float()
    loss = F.cross_entropy(logits, labels)

    outputs = torch.argmax(logits, -1)
    correct = (outputs == labels).float()
    accuracy = torch.mean(correct)

    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])

    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
    output_tensor = model(images)

    return output_tensor, partial(loss_func, labels)

def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
+3 −2
Original line number Diff line number Diff line
@@ -34,13 +34,14 @@ def classification():
        )
        return train_ds, valid_ds

    def model_provider():
    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        print_rank_0("building classification model for ImageNet ...")

        return VitModel(num_classes=args.num_classes, finetune=True)
        return VitModel(num_classes=args.num_classes, finetune=True,
                        pre_process=pre_process, post_process=post_process)

    """Finetune/evaluate."""
    finetune(
+56 −20
Original line number Diff line number Diff line
@@ -16,10 +16,14 @@
"""Evaluation utilities."""

import os
from functools import partial

import torch

from megatron import get_args
from megatron import print_rank_0
from megatron import print_rank_0, print_rank_last
from megatron import mpu
from megatron.schedules import get_forward_backward_func
from tasks.vision.finetune_utils import build_data_loader
from tasks.vision.finetune_utils import process_batch
from torchvision import datasets, transforms
@@ -56,7 +60,7 @@ def accuracy_func_provider():
        print_rank_0("calculating metrics ...")
        correct, total = calculate_correct_answers(model, dataloader, epoch)
        percent = float(correct) * 100.0 / float(total)
        print_rank_0(
        print_rank_last(
            " >> |epoch: {}| overall: correct / total = {} / {} = "
            "{:.4f} %".format(epoch, correct, total, percent)
        )
@@ -67,27 +71,59 @@ def accuracy_func_provider():
def calculate_correct_answers(model, dataloader, epoch):
    """Calculate correct over total answers"""

    model.eval()
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    for m in model:
        m.eval()

    def loss_func(labels, output_tensor):
        logits = output_tensor

        loss_dict = {}
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels).float()
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        return 0, loss_dict

    #defined inside to capture output_predictions
    def correct_answers_forward_step(batch, model):
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        images, labels = process_batch(batch_)

        # Forward model.
        args = get_args()
        output_tensor = model(images)

        return output_tensor, partial(loss_func, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        for _, batch in enumerate(dataloader):
            # Run the model forward.
            images, labels = process_batch(batch)
            logits = model(images).contiguous().float()
            # Add output predictions.
            # Compute the correct answers.
            predicted = torch.argmax(logits, dim=-1)
            corrects = (predicted == labels).float()
            # Add to the counters.
            total += labels.size(0)
            correct += corrects.sum().item()
    model.train()

            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
                                               optimizer=None, timers=None, forward_only=True)

            for loss_dict in loss_dicts:
                total += loss_dict['total']
                correct += loss_dict['correct']

    for m in model:
        m.train()

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
    torch.distributed.all_reduce(unreduced, group=mpu.get_data_parallel_group())
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        correct_ans = unreduced[0].item()
Loading