Commit 84eb016c authored by Mostofa Patwary

Merge branch 'main' into main_retriver_merge_dpr

parents c7c65bbb 83c4d95a
+1 −1
@@ -60,8 +60,8 @@ def check_checkpoint_args(checkpoint_args):
     _compare('num_layers')
     _compare('hidden_size')
     _compare('num_attention_heads')
+    _compare('max_position_embeddings')
     if args.vocab_file:
-        _compare('max_position_embeddings')
        _compare('make_vocab_size_divisible_by')
        _compare('padded_vocab_size')
        _compare('tokenizer_type')
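
The effect of this hunk is that `max_position_embeddings` is now checked even when no vocab file is supplied. For reference, `_compare` is a small helper defined earlier in `check_checkpoint_args` (not shown in this hunk) that asserts a value stored in the checkpoint matches the value in the current run's arguments. A hedged, self-contained sketch of that pattern (the real helper may differ in details, e.g. it reads the current args via `get_args()` rather than a parameter):

```python
from types import SimpleNamespace


def check_checkpoint_args(checkpoint_args, args):
    def _compare(arg_name):
        # Fail loudly if the checkpoint was produced with a different setting.
        checkpoint_value = getattr(checkpoint_args, arg_name)
        args_value = getattr(args, arg_name)
        assert checkpoint_value == args_value, (
            f"{arg_name} mismatch: checkpoint has {checkpoint_value}, "
            f"current arguments have {args_value}"
        )

    _compare('num_layers')
    _compare('max_position_embeddings')


# Toy usage with stand-in argument namespaces.
ckpt = SimpleNamespace(num_layers=24, max_position_embeddings=1024)
current = SimpleNamespace(num_layers=24, max_position_embeddings=1024)
check_checkpoint_args(ckpt, current)
```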
+76 −52
@@ -50,11 +50,11 @@ class VitMlpHead(MegatronModule):
     def forward(self, hidden_states, sequence_index=0):
         # hidden_states: [b, s, h]
         # sequence_index: index of the token to pool.
-        x = hidden_states[:, sequence_index, :]
-        x = self.dense_in(x)
-        x = torch.tanh(x)
-        x = self.dense_out(x)
-        return x
+        hidden_state = hidden_states[:, sequence_index, :]
+        dense_in_result = self.dense_in(hidden_state)
+        tanh_result = torch.tanh(dense_in_result)
+        dense_out_result = self.dense_out(tanh_result)
+        return dense_out_result
 
 
 def twod_interpolate_position_embeddings_hook(
@@ -122,8 +122,12 @@ def twod_interpolate_position_embeddings_hook(
 class VitModel(MegatronModule):
     """Vision Transformer Model."""
 
-    def __init__(self, num_classes, finetune=False):
-        super(VitModel, self).__init__()
+    def __init__(self,
+                 num_classes,
+                 finetune=False,
+                 pre_process=True,
+                 post_process=True):
+        super(VitModel, self).__init__(share_word_embeddings=False)
         args = get_args()
 
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
@@ -136,6 +140,8 @@ class VitModel(MegatronModule):
                 args.init_method_std, args.num_layers
             )
 
+        self.pre_process = pre_process
+        self.post_process = post_process
         self.hidden_size = args.hidden_size
         self.num_classes = num_classes
         self.patch_dim = args.patch_dim
@@ -148,8 +154,11 @@ class VitModel(MegatronModule):
         self.seq_length = self.num_patches + 1
         self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
 
-        # cls_token
-        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
-        torch.nn.init.zeros_(self.cls_token)
+        if self.pre_process:
+            # cls_token
+            self.cls_token = torch.nn.Parameter(
+                torch.randn(1, 1, self.hidden_size)
+            )
+            torch.nn.init.zeros_(self.cls_token)
 
-        # Linear encoder
+            # Linear encoder
@@ -174,9 +183,13 @@ class VitModel(MegatronModule):
 
         # Transformer
         self.transformer = ParallelTransformer(
-            self.init_method, self.scaled_init_method
+            self.init_method,
+            self.scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process
         )
 
-        # MLP head
-        if not self.finetune:
-            self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
+        if self.post_process:
+            # MLP head
+            if not self.finetune:
+                self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
@@ -185,26 +198,37 @@ class VitModel(MegatronModule):
-        else:
-            self.class_head = get_linear_layer(
-                self.hidden_size, num_classes, torch.nn.init.zeros_
-            )
+            else:
+                self.class_head = get_linear_layer(
+                    self.hidden_size, num_classes, torch.nn.init.zeros_
+                )
 
-    def forward(self, x):
-        x = einops.rearrange(
-            x,
-            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
-            p1=self.patch_dim,
-            p2=self.patch_dim,
-        )
+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.transformer.set_input_tensor(input_tensor)
 
-        assert x.dtype == torch.half
-        x = self.linear_encoder(x)
-        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
+    def forward(self, input):
 
-        x = x + self.position_embeddings(self.position_ids)
-        x = self.embedding_dropout(x)
-        x = self.transformer(x, None)
+        if self.pre_process:
+            rearranged_input = einops.rearrange(
+                input,
+                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+                p1=self.patch_dim,
+                p2=self.patch_dim,
+            )
+
+            assert rearranged_input.dtype == torch.half
+            encoder_output = self.linear_encoder(rearranged_input)
+            cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
+            concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
+
+            token_embeddings = concatenated_tokens + \
+                self.position_embeddings(self.position_ids)
+            hidden_states = self.embedding_dropout(token_embeddings)
+        else:
+            hidden_states = input
+
+        hidden_states = self.transformer(hidden_states, None)
 
-        if not self.finetune:
-            x = self.mlp_head(x)
-        else:
-            x = self.class_head(x[:, 0, :])
+        if self.post_process:
+            if not self.finetune:
+                hidden_states = self.mlp_head(hidden_states)
+            else:
+                hidden_states = self.class_head(hidden_states[:, 0, :])
 
-        return x
+        return hidden_states
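
The `pre_process`/`post_process` flags introduced above let one `VitModel` definition act as any stage of a pipeline-parallel model: only the first stage builds the patch-embedding path (cls token, linear encoder, position embeddings), only the last stage builds the classification head, and non-first stages receive their input through `set_input_tensor` instead of the raw image batch. Below is a minimal, self-contained sketch of that wiring with toy modules and hypothetical names (not the actual VitModel or Megatron pipeline code); in the real schedule the activation hand-off between stages is a point-to-point send/recv between ranks rather than a local variable.

```python
import torch


class ToyStage(torch.nn.Module):
    def __init__(self, hidden_size, num_classes, pre_process, post_process):
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        # Only the first stage embeds the raw input; only the last stage has a head.
        self.encoder = torch.nn.Linear(16, hidden_size) if pre_process else None
        self.body = torch.nn.Linear(hidden_size, hidden_size)
        self.head = torch.nn.Linear(hidden_size, num_classes) if post_process else None
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # Non-first stages are fed the previous stage's activation here,
        # mirroring VitModel.set_input_tensor in the diff above.
        self.input_tensor = input_tensor

    def forward(self, batch):
        hidden = self.encoder(batch) if self.pre_process else self.input_tensor
        hidden = self.body(hidden)
        return self.head(hidden) if self.post_process else hidden


# Two-stage wiring on one process, purely for illustration.
stage0 = ToyStage(32, 10, pre_process=True, post_process=False)
stage1 = ToyStage(32, 10, pre_process=False, post_process=True)

activation = stage0(torch.randn(4, 16))
stage1.set_input_tensor(activation)
logits = stage1(None)  # the raw batch is ignored on non-first stages
print(logits.shape)    # torch.Size([4, 10])
```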
+20 −14
@@ -17,19 +17,22 @@
 
 import torch
 import torch.nn.functional as F
+from functools import partial
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vit_model import VitModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 
-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""
 
     print_rank_0("building VIT model ...")
     args = get_args()
 
-    model = VitModel(num_classes=args.num_classes)
+    model = VitModel(num_classes=args.num_classes,
+                     pre_process=pre_process,
+                     post_process=post_process)
     return model
 
 def get_batch(data_iterator):
@@ -42,10 +45,21 @@ def get_batch(data_iterator):
 
     return images, labels
 
-def forward_step(data_iterator, model, input_tensor):
+def loss_func(labels, output_tensor):
+    logits = output_tensor.contiguous().float()
+    loss = F.cross_entropy(logits, labels)
+
+    outputs = torch.argmax(logits, -1)
+    correct = (outputs == labels).float()
+    accuracy = torch.mean(correct)
+
+    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
+
+    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+
+def forward_step(data_iterator, model):
     """Forward step."""
     timers = get_timers()
-    assert input_tensor is None
 
     # Get the batch.
     timers("batch-generator").start()
@@ -56,17 +70,9 @@ def forward_step(data_iterator, model, input_tensor):
     timers("batch-generator").stop()
 
     # Forward model. lm_labels
-    logits = model(images).contiguous().float()
-    loss = F.cross_entropy(logits, labels)
-
-    outputs = torch.argmax(logits, -1)
-    correct = (outputs == labels).float()
-    accuracy = torch.mean(correct)
-
-    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
-
-    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+    output_tensor = model(images)
+
+    return output_tensor, partial(loss_func, labels)
 
 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid, and test datasets."""
+3 −1
@@ -19,7 +19,7 @@ from functools import partial
 import sys
 import torch
 
-from megatron import get_args
+from megatron import get_args, get_num_microbatches
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
@@ -159,6 +159,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
     args = get_args()
     timers = get_timers()
 
+    assert get_num_microbatches() == 1, "finetuning with gradient accumulation doesn't currently work"
+
     # Turn on training mode which enables dropout.
     for m in model:
         m.train()
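
The new assertion makes the finetuning loop's single-micro-batch assumption explicit: `_train` does one forward/backward and one optimizer step per iteration, whereas gradient accumulation spreads one optimizer step over several micro-batches. A rough, self-contained illustration of the pattern the assertion rules out (toy model, not Megatron's training loop):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
num_microbatches = 4  # pretraining-style gradient accumulation

optimizer.zero_grad()
for _ in range(num_microbatches):
    x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y) / num_microbatches
    loss.backward()   # gradients accumulate across micro-batches
optimizer.step()      # a single optimizer step per global batch
```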
+3 −2
@@ -34,13 +34,14 @@ def classification():
         )
         return train_ds, valid_ds
 
-    def model_provider():
+    def model_provider(pre_process=True, post_process=True):
         """Build the model."""
         args = get_args()
 
         print_rank_0("building classification model for ImageNet ...")
 
-        return VitModel(num_classes=args.num_classes, finetune=True)
+        return VitModel(num_classes=args.num_classes, finetune=True,
+                        pre_process=pre_process, post_process=post_process)
 
     """Finetune/evaluate."""
     finetune(