Commit 691747b1 authored by Mohammad Shoeybi

added query-key layer scaling and softmax fp32 option

parent 7a3b4c15
+6 −0
@@ -70,6 +70,12 @@ def add_fp16_config_args(parser):

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode')
+    group.add_argument('--apply-query-key-layer-scaling', action='store_true',
+                       help='Scale Q * K^T by 1 / layer-number. If this flag '
+                       'is set, then it will automatically set '
+                       'attention-softmax-in-fp32 to true')
+    group.add_argument('--attention-softmax-in-fp32', action='store_true',
+                       help='Run attention masking and softmax in fp32.')
    group.add_argument('--fp32-embedding', action='store_true',
                       help='embedding in fp32')
    group.add_argument('--fp32-layernorm', action='store_true',
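
The two new flags follow the usual argparse dash-to-underscore convention, so downstream code reads them as args.apply_query_key_layer_scaling and args.attention_softmax_in_fp32. Below is a minimal sketch of how they might be threaded into the model constructors extended in the rest of this commit; the hand-off itself is assumed, not shown in the diff:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('fp16', 'fp16 configuration')
group.add_argument('--apply-query-key-layer-scaling', action='store_true',
                   help='Scale Q * K^T by 1 / layer-number.')
group.add_argument('--attention-softmax-in-fp32', action='store_true',
                   help='Run attention masking and softmax in fp32.')
args = parser.parse_args(['--apply-query-key-layer-scaling'])

# Dashes become underscores on the parsed namespace.
print(args.apply_query_key_layer_scaling)   # True
print(args.attention_softmax_in_fp32)       # False

# Assumed hand-off into the constructors changed below:
# model = GPT2Model(...,
#                   apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
#                   attention_softmax_in_fp32=args.attention_softmax_in_fp32)
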
+6 −2
@@ -119,7 +119,9 @@ class BertModel(MegatronModule):
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 num_tokentypes=0,
-                 parallel_output=True):
+                 parallel_output=True,
+                 apply_query_key_layer_scaling=False,
+                 attention_softmax_in_fp32=False):

        super(BertModel, self).__init__()

@@ -145,7 +147,9 @@ class BertModel(MegatronModule):
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(init_method_std,
                                                         num_layers),
-            residual_connection_post_layernorm=False)
+            residual_connection_post_layernorm=False,
+            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+            attention_softmax_in_fp32=attention_softmax_in_fp32)

        self.lm_head = BertLMHead(
            self.language_model.embedding.word_embeddings.weight.size(0),
+6 −2
@@ -48,7 +48,9 @@ class GPT2Model(MegatronModule):
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 num_tokentypes=0,
-                 parallel_output=True):
+                 parallel_output=True,
+                 apply_query_key_layer_scaling=False,
+                 attention_softmax_in_fp32=False):

        super(GPT2Model, self).__init__()

@@ -72,7 +74,9 @@ class GPT2Model(MegatronModule):
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(init_method_std,
                                                         num_layers),
-            residual_connection_post_layernorm=False)
+            residual_connection_post_layernorm=False,
+            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+            attention_softmax_in_fp32=attention_softmax_in_fp32)


    def forward(self, input_ids, position_ids, attention_mask,
+6 −2
@@ -60,7 +60,9 @@ def get_language_model(num_layers,
                       layernorm_epsilon,
                       init_method,
                       scaled_init_method,
-                       residual_connection_post_layernorm):
+                       residual_connection_post_layernorm,
+                       apply_query_key_layer_scaling,
+                       attention_softmax_in_fp32):
    # Transformer hyperparameters.
    transformer_hparams = TransformerHyperparameters(
        hidden_size=hidden_size,
@@ -74,7 +76,9 @@ def get_language_model(num_layers,
        output_layer_init_method=scaled_init_method,
        checkpoint_activations=checkpoint_activations,
        checkpoint_num_layers=checkpoint_num_layers,
-        apply_residual_connection_post_layernorm=residual_connection_post_layernorm)
+        apply_residual_connection_post_layernorm=residual_connection_post_layernorm,
+        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+        attention_softmax_in_fp32=attention_softmax_in_fp32)
    # Language model.
    language_model = TransformerLanguageModel(
        transformer_hparams=transformer_hparams,
+36 −11
@@ -82,7 +82,9 @@ class TransformerHyperparameters:
                 output_layer_init_method=None,
                 checkpoint_activations=None,
                 checkpoint_num_layers=None,
-                 apply_residual_connection_post_layernorm=None):
+                 apply_residual_connection_post_layernorm=None,
+                 apply_query_key_layer_scaling=None,
+                 attention_softmax_in_fp32=None):
        self.params_dict = {}
        self.params_dict['hidden_size'] = hidden_size
        self.params_dict['num_layers'] = num_layers
@@ -97,6 +99,10 @@ class TransformerHyperparameters:
        self.params_dict['checkpoint_num_layers'] = checkpoint_num_layers
        self.params_dict['apply_residual_connection_post_layernorm'] \
            = apply_residual_connection_post_layernorm
+        self.params_dict['apply_query_key_layer_scaling'] \
+            = apply_query_key_layer_scaling
+        self.params_dict['attention_softmax_in_fp32'] \
+            = attention_softmax_in_fp32


    def __getitem__(self, key):
@@ -169,10 +175,17 @@ class ParallelSelfAttention(MegatronModule):
    and returns output of the same size.
    """

-    def __init__(self, hyperparameters, attention_mask_func):
+    def __init__(self, hyperparameters, attention_mask_func, layer_number):
        super(ParallelSelfAttention, self).__init__()

        self.attention_mask_func = attention_mask_func
+        self.apply_query_key_layer_scaling \
+            = hyperparameters['apply_query_key_layer_scaling']
+        self.attention_softmax_in_fp32 \
+            = hyperparameters['attention_softmax_in_fp32']
+        if self.apply_query_key_layer_scaling:
+            self.attention_softmax_in_fp32 = True
+        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
@@ -239,7 +252,11 @@ class ParallelSelfAttention(MegatronModule):

    def _get_unmasked_attention_scores(self, query_layer, key_layer):
        """Unmasked attention scores with size [b, np, s, s]."""
-        norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
+        coeff = 1
+        if self.apply_query_key_layer_scaling:
+            coeff = self.layer_number
+        norm_factor = math.sqrt(coeff *
+                                math.sqrt(self.hidden_size_per_attention_head))
        # Raw attention scores. [b, np, s, s]
        return torch.matmul(query_layer/norm_factor,
                            key_layer.transpose(-1, -2)/norm_factor)
@@ -250,7 +267,9 @@ class ParallelSelfAttention(MegatronModule):
        the size [b, np, s, s].
        """
        # Attention probabilities. [b, np, s, s]
-        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
+        if self.apply_query_key_layer_scaling:
+            attention_scores = attention_scores * self.layer_number
+        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
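
The two hunks above implement the scaling trick: Q and K are each divided by sqrt(layer_number * sqrt(d)) before the fp16 matmul, and the resulting scores are multiplied back by layer_number before the softmax, so the probabilities match standard 1/sqrt(d) attention while the fp16 intermediate stays layer_number times smaller. A small numerical check of that equivalence; shapes and values below are illustrative only, not from the commit:

import math
import torch

torch.manual_seed(0)
d = 64                      # hidden size per attention head
layer_number = 24           # coeff used when --apply-query-key-layer-scaling is set
q = torch.randn(2, 4, 16, d)
k = torch.randn(2, 4, 16, d)

# Standard scaled dot-product scores: Q K^T / sqrt(d).
baseline = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d)

# With layer scaling: fold layer_number into the norm factor before the matmul,
# then multiply the scores back by layer_number before the softmax.
norm_factor = math.sqrt(layer_number * math.sqrt(d))
scaled = torch.matmul(q / norm_factor, k.transpose(-1, -2) / norm_factor)
restored = scaled * layer_number

print(torch.allclose(baseline, restored, atol=1e-5))       # True (up to rounding)
print((scaled.abs().max() / baseline.abs().max()).item())  # ~1/layer_number

The intermediate 'scaled' is what would live in fp16 during training; keeping it roughly layer_number times smaller than the baseline scores is what avoids fp16 overflow in deep networks.
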
@@ -304,6 +323,10 @@ class ParallelSelfAttention(MegatronModule):
        attention_scores = self._get_unmasked_attention_scores(
            query_layer, key_layer)

+        # fp32 conversion.
+        if self.attention_softmax_in_fp32:
+            attention_scores = attention_scores.float()
+
        # Apply attention mask. [b, np, s, s]
        if get_key_value:
            with torch.no_grad():
@@ -323,6 +346,10 @@ class ParallelSelfAttention(MegatronModule):
        # Attention probabilities. [b, np, s, s]
        attention_probs = self._get_attention_probs(attention_scores)

+        # fp16 conversion
+        if self.attention_softmax_in_fp32:
+            attention_probs = attention_probs.half()
+
        # Context layer. [b, s, hp]
        context_layer = self._get_attended_context(attention_probs, value_layer)

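
These two hunks bracket the masking and softmax with an fp32 round trip whenever attention_softmax_in_fp32 is set (which --apply-query-key-layer-scaling forces on). A minimal sketch of the dtype pattern in isolation; the masked_fill below is an assumed stand-in for whatever attention_mask_func does:

import torch

scores = torch.randn(2, 4, 16, 16, dtype=torch.float16)   # [b, np, s, s]
mask = torch.zeros(2, 1, 16, 16, dtype=torch.bool)
mask[..., 8:] = True                                       # illustrative padding mask

scores = scores.float()                        # fp32 conversion
scores = scores.masked_fill(mask, -10000.0)    # masking in fp32
probs = torch.nn.Softmax(dim=-1)(scores)       # softmax in fp32
probs = probs.half()                           # fp16 conversion for the context matmul

print(probs.dtype)               # torch.float16
print(probs.sum(-1)[0, 0, 0])    # ~1.0
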
@@ -342,7 +369,7 @@ class ParallelTransformerLayer(MegatronModule):
    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """
-    def __init__(self, hyperparameters, attention_mask_func):
+    def __init__(self, hyperparameters, attention_mask_func, layer_number):

        super(ParallelTransformerLayer, self).__init__()

@@ -356,8 +383,7 @@ class ParallelTransformerLayer(MegatronModule):

        # Self attention.
        self.attention = ParallelSelfAttention(
-            hyperparameters,
-            attention_mask_func)
+            hyperparameters, attention_mask_func, layer_number)

        # Layernorm on the input data.
        self.post_attention_layernorm = LayerNorm(
@@ -414,14 +440,13 @@ class ParallelTransformer(MegatronModule):
        self.checkpoint_activations = hyperparameters['checkpoint_activations']
        self.checkpoint_num_layers = hyperparameters['checkpoint_num_layers']

-        def get_layer():
+        def get_layer(layer_number):
            return ParallelTransformerLayer(
-                hyperparameters,
-                attention_mask_func)
+                hyperparameters, attention_mask_func, layer_number)

        # Transformer layers.
        self.layers = torch.nn.ModuleList(
-            [get_layer() for _ in range(hyperparameters['num_layers'])])
+            [get_layer(i+1) for i in range(hyperparameters['num_layers'])])

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(
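
The final hunks hand each ParallelTransformerLayer a 1-based layer number, which ParallelSelfAttention clamps with max(1, layer_number). Layer 1 therefore keeps the standard sqrt(sqrt(d)) split of the 1/sqrt(d) scaling, while deeper layers damp the fp16 Q*K^T intermediate progressively harder. A tiny sketch of the per-layer norm factor (d = 64 is an assumed example value):

import math

hidden_size_per_attention_head = 64
for layer_number in range(1, 5):
    coeff = max(1, layer_number)
    norm_factor = math.sqrt(coeff * math.sqrt(hidden_size_per_attention_head))
    print(layer_number, round(norm_factor, 3))
# 1 2.828  -> same as the original sqrt(sqrt(d)) factor
# 2 4.0
# 3 4.899
# 4 5.657
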