Commit 90ef2e28 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi Committed by Raul Puri
Browse files

BERT downstream tasks: RACE, MNLI, QQP

parent 41141526
Loading
Loading
Loading
Loading
+5 −3
Original line number Diff line number Diff line
@@ -346,10 +346,10 @@ def add_data_args(parser):
                       help='path used to save/load sentencepiece tokenization '
                       'models')
    group.add_argument('--tokenizer-type', type=str,
                       default='BertWordPieceTokenizer',
                       default='BertWordPieceLowerCase',
                       choices=['CharacterLevelTokenizer',
                                'SentencePieceTokenizer',
                                'BertWordPieceTokenizer',
                                'BertWordPieceLowerCase',
                                'GPT2BPETokenizer'],
                       help='what type of tokenizer to use')
    group.add_argument("--cache-dir", default=None, type=str,
@@ -358,7 +358,7 @@ def add_data_args(parser):
    return parser


def get_args():
def get_args(extra_args_provider=None):
    """Parse all the args."""

    parser = argparse.ArgumentParser(description='PyTorch BERT Model')
@@ -368,6 +368,8 @@ def get_args():
    parser = add_evaluation_args(parser)
    parser = add_text_generate_args(parser)
    parser = add_data_args(parser)
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    args = parser.parse_args()

+100 −0
Original line number Diff line number Diff line

"""Megatron tokenizer."""


from abc import ABC
from abc import abstractmethod

from megatron.utils import vocab_size_with_padding
from .bert_tokenization import FullTokenizer as FullBertTokenizer


def add_tokenizer_to_args(args, tokenizer_type):
    """Instantiate tokenizer based on input type and add it to args."""

    # Make sure we have not already called this method.
    if hasattr(args, 'tokenizer'):
        raise Exception('args already has a tokenizer')
    # Select and instantiate the tokenizer.
    if tokenizer_type == 'BertWordPieceLowerCase':
        args.tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab,
                                                 lower_case=True)
    else:
        raise NotImplementedError('{} tokenizer is not '
                                  'implemented.'.format(tokenizer_type))

    # Add vocab size.
    args.vocab_size = vocab_size_with_padding(args.tokenizer.vocab_size, args)


class AbstractTokenizer(ABC):
    """Abstract class for tokenizer."""

    def __init__(self, name):
        self.name = name
        super().__init__()

    @property
    @abstractmethod
    def vocab_size(self):
        pass

    @abstractmethod
    def tokenize(self, text):
        pass

    @property
    def cls(self):
        raise NotImplementedError('CLS is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def sep(self):
        raise NotImplementedError('SEP is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def pad(self):
        raise NotImplementedError('PAD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def eod(self):
        raise NotImplementedError('EOD is not provided for {} '
                                  'tokenizer'.format(self.name))



class _BertWordPieceTokenizer(AbstractTokenizer):
    """Original BERT wordpiece tokenizer."""

    def __init__(self, vocab_file, lower_case=True):
        if lower_case:
            name = 'BERT Lower Case'
        else:
            name = 'BERT Upper Case'
        super().__init__(name)
        self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
        self.cls_id = self.tokenizer.vocab['[CLS]']
        self.sep_id = self.tokenizer.vocab['[SEP]']
        self.pad_id = self.tokenizer.vocab['[PAD]']

    @property
    def vocab_size(self):
        return self.tokenizer.vocab_size()

    def tokenize(self, text):
        text_tokens = self.tokenizer.tokenize(text)
        return self.tokenizer.convert_tokens_to_ids(text_tokens)

    @property
    def cls(self):
        return self.cls_id

    @property
    def sep(self):
        return self.sep_id

    @property
    def pad(self):
        return self.pad_id
+133 −0
Original line number Diff line number Diff line
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classification model."""

import torch

from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron.utils import print_rank_0


class Classification(MegatronModule):

    def __init__(self,
                 num_classes,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 num_tokentypes=2,
                 apply_query_key_layer_scaling=False,
                 attention_softmax_in_fp32=False):

        super(Classification, self).__init__()

        self.num_classes = num_classes
        init_method = init_method_normal(init_method_std)

        self.language_model, self._language_model_key = get_language_model(
            num_layers=num_layers,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            embedding_dropout_prob=embedding_dropout_prob,
            attention_dropout_prob=attention_dropout_prob,
            output_dropout_prob=output_dropout_prob,
            max_sequence_length=max_sequence_length,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            attention_mask_func=bert_attention_mask_func,
            checkpoint_activations=checkpoint_activations,
            checkpoint_num_layers=checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(init_method_std,
                                                         num_layers),
                        residual_connection_post_layernorm=False,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            attention_softmax_in_fp32=attention_softmax_in_fp32)

        # Multi-choice head.
        self.classification_dropout = torch.nn.Dropout(output_dropout_prob)
        self.classification_head = get_linear_layer(hidden_size,
                                                    self.num_classes,
                                                    init_method)
        self._classification_head_key = 'classification_head'


    def forward(self, input_ids, attention_mask, tokentype_ids):

        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        _, pooled_output = self.language_model(input_ids,
                                               position_ids,
                                               extended_attention_mask,
                                               tokentype_ids=tokentype_ids)

        # Output.
        classification_output = self.classification_dropout(pooled_output)
        classification_logits = self.classification_head(classification_output)

        # Reshape back to separate choices.
        classification_logits = classification_logits.view(-1, self.num_classes)

        return classification_logits


    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        state_dict_[self._classification_head_key] \
            = self.classification_head.state_dict(
                destination, prefix, keep_vars)
        return state_dict_


    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self._classification_head_key in state_dict:
            self.classification_head.load_state_dict(
                state_dict[self._classification_head_key], strict=strict)
        else:
            print_rank_0('***WARNING*** could not find {} in the checkpoint, '
                         'initializing to random'.format(
                             self._classification_head_key))
+143 −0
Original line number Diff line number Diff line
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multiple choice model."""

import torch

from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron.utils import print_rank_0


class MultipleChoice(MegatronModule):

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 num_tokentypes=2,
                 apply_query_key_layer_scaling=False,
                 attention_softmax_in_fp32=False):

        super(MultipleChoice, self).__init__()

        init_method = init_method_normal(init_method_std)

        self.language_model, self._language_model_key = get_language_model(
            num_layers=num_layers,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            embedding_dropout_prob=embedding_dropout_prob,
            attention_dropout_prob=attention_dropout_prob,
            output_dropout_prob=output_dropout_prob,
            max_sequence_length=max_sequence_length,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            attention_mask_func=bert_attention_mask_func,
            checkpoint_activations=checkpoint_activations,
            checkpoint_num_layers=checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(init_method_std,
                                                         num_layers),
            residual_connection_post_layernorm=False,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            attention_softmax_in_fp32=attention_softmax_in_fp32)

        # Multi-choice head.
        self.multichoice_dropout = torch.nn.Dropout(output_dropout_prob)
        self.multichoice_head = get_linear_layer(hidden_size, 1, init_method)
        self._multichoice_head_key = 'multichoice_head'


    def forward(self, input_ids, attention_mask, tokentype_ids):

        # [batch, choices, sequence] --> [batch * choices, sequence] -->
        #    transformer --> [batch, choices] --> softmax

        # Ensure the shape is [batch-size, choices, sequence]
        assert len(input_ids.shape) == 3
        assert len(attention_mask.shape) == 3
        assert len(tokentype_ids.shape) == 3

        # Reshape and treat choice dimension the same as batch.
        num_choices = input_ids.shape[1]
        input_ids = input_ids.view(-1, input_ids.size(-1))
        attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        _, pooled_output = self.language_model(input_ids,
                                               position_ids,
                                               extended_attention_mask,
                                               tokentype_ids=tokentype_ids)

        # Output.
        multichoice_output = self.multichoice_dropout(pooled_output)
        multichoice_logits = self.multichoice_head(multichoice_output)

        # Reshape back to separate choices.
        multichoice_logits = multichoice_logits.view(-1, num_choices)

        return multichoice_logits


    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        state_dict_[self._multichoice_head_key] \
            = self.multichoice_head.state_dict(
                destination, prefix, keep_vars)
        return state_dict_


    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self._multichoice_head_key in state_dict:
            self.multichoice_head.load_state_dict(
                state_dict[self._multichoice_head_key], strict=strict)
        else:
            print_rank_0('***WARNING*** could not find {} in the checkpoint, '
                         'initializing to random'.format(
                             self._multichoice_head_key))
+13 −13
Original line number Diff line number Diff line
@@ -43,7 +43,7 @@ from megatron.utils import Timers


def run(top_level_message, train_val_test_data_provider,
        model_provider, forward_step_func):
        model_provider, forward_step_func, extra_args_provider=None):
    """Main training program.

    This function will run the followings in the order provided:
@@ -71,17 +71,9 @@ def run(top_level_message, train_val_test_data_provider,
            function add `batch generator` to the timers class.
    """

    # Arguments.
    args = get_args()

    # Timer.
    timers = Timers()

    # Tensorboard writer
    writer = get_tensorboard_writer(args)

    # Initalize.
    initialize_megatron(top_level_message, args, writer)
    # Initalize and get arguments, timers, and Tensorboard writer.
    args = get_args(extra_args_provider=extra_args_provider)
    timers, writer = initialize_megatron(top_level_message, args)

    # Data stuff.
    train_data, val_data, test_data = train_val_test_data_provider(args)
@@ -124,9 +116,15 @@ def run(top_level_message, train_val_test_data_provider,
                                   args, None, 0, timers, True)


def initialize_megatron(message, args, writer):
def initialize_megatron(message, args):
    """"Initialize distributed, random seed, and autoresume."""

    # Timer.
    timers = Timers()

    # Tensorboard writer.
    writer = get_tensorboard_writer(args)

    # Pytorch distributed.
    initialize_distributed(args)
    if torch.distributed.get_rank() == 0:
@@ -141,6 +139,8 @@ def initialize_megatron(message, args, writer):
    # Random seeds for reproducability.
    set_random_seed(args.seed)

    return timers, writer


def get_model(model_provider_func, args):
    """Build the model."""
Loading