Commit deffcb6a authored by Mohammad

arguments in the model refactored

parent 601b19b7
+4 −0
@@ -108,6 +108,10 @@ def _add_network_size_args(parser):
                       'This is added for computational efficiency reasons.')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residual connection '
                       'ordering.')

    return parser
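
For reference, a minimal standalone sketch (not part of this commit) of how the new store_true flag surfaces once parsing has run; the parser construction below only mimics _add_network_size_args and is otherwise hypothetical:

import argparse

parser = argparse.ArgumentParser(description='sketch of the new flag')
group = parser.add_argument_group(title='network size')
group.add_argument('--apply-residual-connection-post-layernorm',
                   action='store_true',
                   help='If set, use original BERT residual connection '
                   'ordering.')

# Omitting the flag leaves it False; dashes become underscores on the namespace.
args = parser.parse_args([])
assert args.apply_residual_connection_post_layernorm is False

# Passing the flag flips it to True, which downstream code can then read.
args = parser.parse_args(['--apply-residual-connection-post-layernorm'])
assert args.apply_residual_connection_post_layernorm is True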

+13 −39
@@ -17,6 +17,7 @@

import torch

from megatron import get_args
from megatron.module import MegatronModule

from .language_model import parallel_lm_logits
@@ -106,60 +107,33 @@ class BertLMHead(MegatronModule):
class BertModel(MegatronModule):
    """Bert Language model."""

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 add_binary_head=False,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 num_tokentypes=0,
                 parallel_output=True,
                 apply_query_key_layer_scaling=False,
                 attention_softmax_in_fp32=False):

    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        init_method = init_method_normal(init_method_std)
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_layers=num_layers,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            embedding_dropout_prob=embedding_dropout_prob,
            attention_dropout_prob=attention_dropout_prob,
            output_dropout_prob=output_dropout_prob,
            max_sequence_length=max_sequence_length,
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            attention_mask_func=bert_attention_mask_func,
            checkpoint_activations=checkpoint_activations,
            checkpoint_num_layers=checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(init_method_std,
                                                         num_layers),
            residual_connection_post_layernorm=False,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            attention_softmax_in_fp32=attention_softmax_in_fp32)
            scaled_init_method=scaled_init_method)

        self.lm_head = BertLMHead(
            self.language_model.embedding.word_embeddings.weight.size(0),
            hidden_size, init_method, layernorm_epsilon, parallel_output)
            args.hidden_size, init_method, args.layernorm_epsilon,
            parallel_output)
        self._lm_head_key = 'lm_head'

        if self.add_binary_head:
            self.binary_head = get_linear_layer(hidden_size, 2, init_method)
            self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                init_method)
            self._binary_head_key = 'binary_head'
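
The pattern driving this refactor is megatron's global argument store: the constructor no longer receives every hyperparameter and instead pulls them from get_args(). Below is a self-contained sketch of that pattern; _GLOBAL_ARGS, set_args, and ToyBertModel are illustrative stand-ins, not megatron's actual globals module:

from types import SimpleNamespace

_GLOBAL_ARGS = None  # hypothetical stand-in for megatron's global argument store

def set_args(args):
    global _GLOBAL_ARGS
    _GLOBAL_ARGS = args

def get_args():
    assert _GLOBAL_ARGS is not None, 'arguments are not initialized'
    return _GLOBAL_ARGS

class ToyBertModel:
    # Mirrors the slimmed-down signature above: only structural choices are
    # passed in; sizes, dropouts, and init std come from the shared args.
    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        args = get_args()
        self.hidden_size = args.hidden_size
        self.init_method_std = args.init_method_std
        self.num_tokentypes = num_tokentypes
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output

set_args(SimpleNamespace(hidden_size=1024, init_method_std=0.02))
model = ToyBertModel(num_tokentypes=2, add_binary_head=True)
assert model.hidden_size == 1024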


+9 −38
@@ -17,6 +17,7 @@

import torch

from megatron import get_args
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
@@ -30,54 +31,24 @@ from megatron import print_rank_0

class Classification(MegatronModule):

    def __init__(self,
                 num_classes,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 num_tokentypes=2,
                 apply_query_key_layer_scaling=False,
                 attention_softmax_in_fp32=False):

    def __init__(self, num_classes, num_tokentypes=2):
        super(Classification, self).__init__()
        args = get_args()

        self.num_classes = num_classes
        init_method = init_method_normal(init_method_std)
        init_method = init_method_normal(args.init_method_std)

        self.language_model, self._language_model_key = get_language_model(
            num_layers=num_layers,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            embedding_dropout_prob=embedding_dropout_prob,
            attention_dropout_prob=attention_dropout_prob,
            output_dropout_prob=output_dropout_prob,
            max_sequence_length=max_sequence_length,
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            attention_mask_func=bert_attention_mask_func,
            checkpoint_activations=checkpoint_activations,
            checkpoint_num_layers=checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(init_method_std,
                                                         num_layers),
            residual_connection_post_layernorm=False,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            attention_softmax_in_fp32=attention_softmax_in_fp32)
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))

        # Multi-choice head.
        self.classification_dropout = torch.nn.Dropout(output_dropout_prob)
        self.classification_head = get_linear_layer(hidden_size,
        self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
        self.classification_head = get_linear_layer(args.hidden_size,
                                                    self.num_classes,
                                                    init_method)
        self._classification_head_key = 'classification_head'
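
The classification head above now takes its sizes from args as well. Here is a sketch of that head construction, with approximations of init_method_normal and get_linear_layer (the real helpers live in megatron.model.utils and may differ in detail) and hypothetical values standing in for args.hidden_size, args.hidden_dropout, and num_classes:

import torch

def init_method_normal(std):
    # Zero-mean normal init with the given std, as used for init_method above.
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_

def get_linear_layer(rows, columns, init_method):
    # Approximation: linear layer with the given weight init and a zeroed bias.
    layer = torch.nn.Linear(rows, columns)
    init_method(layer.weight)
    with torch.no_grad():
        layer.bias.zero_()
    return layer

hidden_size, hidden_dropout, num_classes = 1024, 0.1, 3  # hypothetical args values
classification_dropout = torch.nn.Dropout(hidden_dropout)
classification_head = get_linear_layer(hidden_size, num_classes,
                                       init_method_normal(0.02))

pooled = torch.randn(8, hidden_size)                          # fake pooled BERT output
logits = classification_head(classification_dropout(pooled))  # [batch, num_classes]
assert logits.shape == (8, num_classes)
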
+7 −36
@@ -17,6 +17,7 @@

import torch

from megatron import get_args
from megatron.module import MegatronModule

from .language_model import parallel_lm_logits
@@ -34,49 +35,19 @@ def gpt2_attention_mask_func(attention_scores, ltor_mask):
class GPT2Model(MegatronModule):
    """GPT-2 Language model."""

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 num_tokentypes=0,
                 parallel_output=True,
                 apply_query_key_layer_scaling=False,
                 attention_softmax_in_fp32=False):

    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(GPT2Model, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output

        self.language_model, self._language_model_key = get_language_model(
            num_layers=num_layers,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            embedding_dropout_prob=embedding_dropout_prob,
            attention_dropout_prob=attention_dropout_prob,
            output_dropout_prob=output_dropout_prob,
            max_sequence_length=max_sequence_length,
            attention_mask_func=gpt2_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            attention_mask_func=gpt2_attention_mask_func,
            checkpoint_activations=checkpoint_activations,
            checkpoint_num_layers=checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(init_method_std,
                                                         num_layers),
            residual_connection_post_layernorm=False,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            attention_softmax_in_fp32=attention_softmax_in_fp32)
            init_method=init_method_normal(args.init_method_std),
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))


    def forward(self, input_ids, position_ids, attention_mask,
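
Both the BERT and GPT-2 paths above now derive the output-layer initializer from args.init_method_std and args.num_layers. A sketch of how scaled_init_method_normal is commonly defined follows: it divides the base std by sqrt(2 * num_layers) so residual-branch outputs stay well scaled. Treat the exact formula as an assumption about the repo, not a quote from it:

import math
import torch

def scaled_init_method_normal(sigma, num_layers):
    # Assumed definition: shrink the std by sqrt(2 * num_layers).
    std = sigma / math.sqrt(2.0 * num_layers)
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_

# With init_method_std=0.02 and num_layers=24, the output-layer std drops
# from 0.02 to roughly 0.0029.
init_fn = scaled_init_method_normal(0.02, 24)
weight = torch.empty(1024, 1024)
init_fn(weight)
print(round(weight.std().item(), 4))  # ~0.0029
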
+26 −58
@@ -18,13 +18,13 @@
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule

from .transformer import ParallelTransformer
from .transformer import TransformerHyperparameters
from .utils import gelu
from .utils import get_linear_layer
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import gelu
from megatron.model.utils import get_linear_layer


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
@@ -40,52 +40,20 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
    # Gather if needed.
    if parallel_output:
        return logits_parallel
    else:

    return mpu.gather_from_model_parallel_region(logits_parallel)


def get_language_model(num_layers,
                       vocab_size,
                       hidden_size,
                       num_attention_heads,
                       embedding_dropout_prob,
                       attention_dropout_prob,
                       output_dropout_prob,
                       max_sequence_length,
                       num_tokentypes,
                       attention_mask_func,
                       add_pooler,
                       checkpoint_activations,
                       checkpoint_num_layers,
                       layernorm_epsilon,
                       init_method,
                       scaled_init_method,
                       residual_connection_post_layernorm,
                       apply_query_key_layer_scaling,
                       attention_softmax_in_fp32):
    # Transformer hyperparameters.
    transformer_hparams = TransformerHyperparameters(
        hidden_size=hidden_size,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        attention_dropout_prob=attention_dropout_prob,
        output_dropout_prob=output_dropout_prob,
        mlp_activation_func=gelu,
        layernorm_epsilon=layernorm_epsilon,
        init_method=init_method,
        output_layer_init_method=scaled_init_method,
        checkpoint_activations=checkpoint_activations,
        checkpoint_num_layers=checkpoint_num_layers,
        apply_residual_connection_post_layernorm=residual_connection_post_layernorm,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
        attention_softmax_in_fp32=attention_softmax_in_fp32)
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                       init_method, scaled_init_method):
    """Build language model and return along with the key to save."""

    # Language model.
    language_model = TransformerLanguageModel(
        transformer_hparams=transformer_hparams,
        attention_mask_func=attention_mask_func,
        vocab_size=vocab_size,
        max_sequence_length=max_sequence_length,
        embedding_dropout_prob=embedding_dropout_prob,
        mlp_activation_func=gelu,
        init_method=init_method,
        output_layer_init_method=scaled_init_method,
        num_tokentypes=num_tokentypes,
        add_pooler=add_pooler)
    # key used for checkpoints.
@@ -293,33 +261,33 @@ class TransformerLanguageModel(MegatronModule):
                        will ignore this embedding
    """
    def __init__(self,
                 transformer_hparams,
                 attention_mask_func,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 mlp_activation_func,
                 init_method,
                 output_layer_init_method,
                 num_tokentypes=0,
                 add_pooler=False):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.hidden_size = transformer_hparams['hidden_size']
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = transformer_hparams['init_method']
        self.init_method = init_method
        self.add_pooler = add_pooler

        # Embeddings
        self.embedding = Embedding(self.hidden_size,
                                   vocab_size,
                                   max_sequence_length,
                                   embedding_dropout_prob,
                                   args.padded_vocab_size,
                                   args.max_position_embeddings,
                                   args.hidden_dropout,
                                   self.init_method,
                                   self.num_tokentypes)
        self._embedding_key = 'embedding'

        # Transformer
        self.transformer = ParallelTransformer(
            transformer_hparams,
            attention_mask_func)
            attention_mask_func, mlp_activation_func,
            self.init_method, output_layer_init_method)
        self._transformer_key = 'transformer'

        # Pooler