Commit a4cb4153 authored by Jared Casper

Merge branch 'scale_qkt_exp_lr' into 'master'

Scale Q*K (query times key) by 1/layer-number and add exponential decay option

See merge request ADLR/megatron-lm!27
parents 4ad5717e 6c2d0337
+6 −0
@@ -70,6 +70,12 @@ def add_fp16_config_args(parser):
 
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode')
+    group.add_argument('--apply-query-key-layer-scaling', action='store_true',
+                       help='Scale Q * K^T by 1 / layer-number. If this flag '
+                       'is set, then it will automatically set '
+                       'attention-softmax-in-fp32 to true')
+    group.add_argument('--attention-softmax-in-fp32', action='store_true',
+                       help='Run attention masking and softmax in fp32.')
     group.add_argument('--fp32-embedding', action='store_true',
                        help='embedding in fp32')
     group.add_argument('--fp32-layernorm',
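The help text for --apply-query-key-layer-scaling says that setting it will automatically turn on attention-softmax-in-fp32. The code that enforces that coupling is not part of the hunks shown here, so the following is only a minimal sketch of how such argument post-processing could look; the standalone parser and the post-processing step are assumptions for illustration, not the repository's code.

import argparse

# Hypothetical standalone parser; in the repository the flags are added to an
# argument group inside add_fp16_config_args(parser).
parser = argparse.ArgumentParser()
parser.add_argument('--apply-query-key-layer-scaling', action='store_true')
parser.add_argument('--attention-softmax-in-fp32', action='store_true')
args = parser.parse_args(['--apply-query-key-layer-scaling'])

# Per the help string, the first flag implies the second; a post-processing
# step along these lines (assumed, not shown in the diff) enforces it.
if args.apply_query_key_layer_scaling:
    args.attention_softmax_in_fp32 = True

assert args.attention_softmax_in_fp32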
+4 −1
@@ -24,7 +24,7 @@ from megatron.utils import print_rank_0
 class AnnealingLR(_LRScheduler):
     """Anneals the learning rate"""
 
-    DECAY_STYLES = ['linear', 'cosine', 'constant', 'None']
+    DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
 
     def __init__(self, optimizer, start_lr, warmup_iter, num_iters,
                  decay_style=None, last_iter=-1, min_lr=0.0,
@@ -57,6 +57,9 @@ class AnnealingLR(_LRScheduler):
                 lr = self.start_lr * ((self.end_iter - (num_iters_ - self.warmup_iter)) / self.end_iter)
             elif self.decay_style == self.DECAY_STYLES[1]:
                 lr = self.start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - self.warmup_iter) / self.end_iter) + 1)
+            elif self.decay_style == self.DECAY_STYLES[2]:
+                # exp(-0.693) = 1/2
+                lr = self.start_lr * math.exp(-0.693 * (num_iters_ - self.warmup_iter) / self.end_iter)
             else:
                 lr = self.start_lr
             return max(lr, self.min_lr)
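The new 'exponential' decay style uses exp(-0.693) ≈ 1/2, so over a full decay window of end_iter iterations the learning rate falls smoothly to roughly half of start_lr. Below is a small self-contained sketch of the added branch; the start_lr, warmup_iter, and end_iter values are hypothetical and chosen only for illustration.

import math

# Hypothetical values for illustration; only the decay formula comes from the diff.
start_lr, warmup_iter, end_iter, min_lr = 1.5e-4, 1000, 100000, 0.0

def exponential_lr(num_iters):
    # Mirrors the added DECAY_STYLES[2] branch in AnnealingLR.get_lr:
    # since exp(-0.693) ~= 1/2, lr decays to about half of start_lr by end_iter.
    lr = start_lr * math.exp(-0.693 * (num_iters - warmup_iter) / end_iter)
    return max(lr, min_lr)

print(exponential_lr(warmup_iter))             # start of decay: 1.5e-04
print(exponential_lr(warmup_iter + end_iter))  # end of decay:   ~7.5e-05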
+6 −2
@@ -119,7 +119,9 @@ class BertModel(MegatronModule):
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
                  num_tokentypes=0,
-                 parallel_output=True):
+                 parallel_output=True,
+                 apply_query_key_layer_scaling=False,
+                 attention_softmax_in_fp32=False):
 
         super(BertModel, self).__init__()
 
@@ -145,7 +147,9 @@
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(init_method_std,
                                                          num_layers),
-            residual_connection_post_layernorm=False)
+            residual_connection_post_layernorm=False,
+            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+            attention_softmax_in_fp32=attention_softmax_in_fp32)
 
         self.lm_head = BertLMHead(
             self.language_model.embedding.word_embeddings.weight.size(0),
+6 −2
@@ -48,7 +48,9 @@ class GPT2Model(MegatronModule):
                  layernorm_epsilon=1.0e-5,
                  init_method_std=0.02,
                  num_tokentypes=0,
-                 parallel_output=True):
+                 parallel_output=True,
+                 apply_query_key_layer_scaling=False,
+                 attention_softmax_in_fp32=False):
 
         super(GPT2Model, self).__init__()
 
@@ -72,7 +74,9 @@
             init_method=init_method_normal(init_method_std),
             scaled_init_method=scaled_init_method_normal(init_method_std,
                                                          num_layers),
-            residual_connection_post_layernorm=False)
+            residual_connection_post_layernorm=False,
+            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+            attention_softmax_in_fp32=attention_softmax_in_fp32)
 
 
     def forward(self, input_ids, position_ids, attention_mask,
+6 −2
@@ -60,7 +60,9 @@ def get_language_model(num_layers,
                        layernorm_epsilon,
                        init_method,
                        scaled_init_method,
-                       residual_connection_post_layernorm):
+                       residual_connection_post_layernorm,
+                       apply_query_key_layer_scaling,
+                       attention_softmax_in_fp32):
     # Transformer hyperparameters.
     transformer_hparams = TransformerHyperparameters(
         hidden_size=hidden_size,
@@ -74,7 +76,9 @@
         output_layer_init_method=scaled_init_method,
         checkpoint_activations=checkpoint_activations,
         checkpoint_num_layers=checkpoint_num_layers,
-        apply_residual_connection_post_layernorm=residual_connection_post_layernorm)
+        apply_residual_connection_post_layernorm=residual_connection_post_layernorm,
+        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
+        attention_softmax_in_fp32=attention_softmax_in_fp32)
     # Language model.
     language_model = TransformerLanguageModel(
         transformer_hparams=transformer_hparams,
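The two hyperparameters threaded through get_language_model here are ultimately consumed by the transformer's attention code, which is not part of the hunks shown in this merge request. Taking the flag descriptions literally, a minimal sketch of their effect on the attention probabilities might look like the following; the function name, tensor shapes, and exact order of operations are assumptions for illustration, not the repository's implementation.

import torch

def attention_probs_sketch(query, key, layer_number,
                           apply_query_key_layer_scaling=True,
                           attention_softmax_in_fp32=True):
    # Standard scaled dot-product attention scores.
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
    if apply_query_key_layer_scaling:
        # "Scale Q * K^T by 1 / layer-number" (layers assumed 1-indexed).
        scores = scores / layer_number
    if attention_softmax_in_fp32:
        # "Run attention masking and softmax in fp32", then cast back.
        return torch.softmax(scores.float(), dim=-1).to(scores.dtype)
    return torch.softmax(scores, dim=-1)

# Tiny usage example with made-up shapes: [batch, heads, seq, head_dim].
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
probs = attention_probs_sketch(q, k, layer_number=3)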