Commit 63262827 authored by Raul Puri

Merge branch 'staging_text_generation' into 'staging'

Refactoring text generation

See merge request ADLR/megatron-lm!39
parents 3977b721 fffa0497
+95 −0
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sample Generate GPT2"""

from megatron import get_args
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPT2Model
from megatron.training import get_model
from megatron.text_generation_utils import generate_and_write_samples_unconditional
from megatron.text_generation_utils import generate_samples_input_from_file
from megatron.text_generation_utils import generate_samples_interactive


def model_provider():
    """Build the model."""

    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_tokentypes=0, parallel_output=False)

    return model


def add_text_generate_args(parser):
    """Text generation arguments."""
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='Use greedy sampling.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    group.add_argument("--sample-input-file", type=str, default=None,
                       help='Get input from file instead of interactive mode, '
                       'each line is an input.')
    group.add_argument("--sample-output-file", type=str, default=None,
                       help='Output file got from --sample-input-file')
    group.add_argument("--num-samples", type=int, default=0,
                       help='Number of samples to generate unconditionally, '
                       'defaults to 0 and interactive conditional sampling')
    group.add_argument("--genfile", type=str,
                       help='Output file when generating unconditionally')
    group.add_argument("--recompute", action='store_true',
                       help='During generation recompute all attention '
                       'instead of using previously computed keys/values.')

    return parser


def main():
    """Main program."""

    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})

    # Set up model and load checkpoint.
    model = get_model(model_provider)
    args = get_args()
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    # Generate samples.
    if args.num_samples == 0:
        args.batch_size = 1
        if args.sample_input_file is not None:
            generate_samples_input_from_file(model)
        else:
            generate_samples_interactive(model)
    else:
        generate_and_write_samples_unconditional(model)


if __name__ == "__main__":

    main()
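
The sampling flags defined above (--greedy, --temperature, --top_k, --top_p) are consumed by the generation routines in megatron/text_generation_utils.py, whose 411-line diff is collapsed at the end of this merge request. As a rough orientation only, here is a minimal sketch of the usual way such flags are applied to a logits vector; the function and variable names are illustrative and are not taken from that file:

# Illustrative sketch of greedy / temperature / top-k / top-p sampling.
# This is NOT Megatron's implementation; see megatron/text_generation_utils.py.
import torch

def sample_next_token(logits, temperature=1.0, top_k=0, top_p=0.0, greedy=False):
    """Pick a token id from a 1-D [vocab_size] logits tensor."""
    if greedy:
        return torch.argmax(logits).item()

    logits = logits / max(temperature, 1e-5)

    if top_k > 0:
        # Keep only the k highest-scoring tokens.
        kth_best = torch.topk(logits, top_k).values[-1]
        logits[logits < kth_best] = float('-inf')

    if top_p > 0.0:
        # Nucleus sampling: keep the smallest set of tokens whose
        # cumulative probability exceeds top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        to_remove = cumulative > top_p
        to_remove[1:] = to_remove[:-1].clone()  # always keep the best token
        to_remove[0] = False
        logits[sorted_idx[to_remove]] = float('-inf')

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()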
+6 −30
@@ -69,7 +69,9 @@ def parse_args(extra_args_provider=None, defaults={}):

    # Checks.
    assert args.hidden_size % args.num_attention_heads == 0
-    assert args.max_position_embeddings >= args.seq_length
-    assert args.min_lr <= args.lr
+    if args.seq_length is not None:
+        assert args.max_position_embeddings >= args.seq_length
+    if args.lr is not None:
+        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
@@ -134,7 +136,7 @@ def _add_regularization_args(parser):
def _add_training_args(parser):
    group = parser.add_argument_group(title='training')

-    group.add_argument('--batch-size', type=int, required=True,
+    group.add_argument('--batch-size', type=int, default=None,
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
                       'parallel size.')
@@ -301,7 +303,7 @@ def _add_data_args(parser):
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
-    group.add_argument('--seq-length', type=int, required=True,
+    group.add_argument('--seq-length', type=int, default=None,
                       help="Maximum sequence length to process.")
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
@@ -356,32 +358,6 @@ def _add_gpt2_args(parser):




-def add_text_generate_args(parser):
-    """Text generate arguments."""
-
-    group = parser.add_argument_group('Text generation', 'configurations')
-    group.add_argument("--temperature", type=float, default=1.0)
-    group.add_argument("--greedy", action='store_true', default=False)
-    group.add_argument("--top_p", type=float, default=0.0)
-    group.add_argument("--top_k", type=int, default=0)
-    group.add_argument("--out-seq-length", type=int, default=1024)
-    group.add_argument("--sample-input-file", type=str, default="",
-                      help='get input from file instead of interactive mode, '
-                           'each line is an input' )
-    group.add_argument("--sample-output-file", type=str, default="",
-                      help='output file got from --sample-input-file')
-    group.add_argument("--num-samples", type=int, default=0,
-                       help='number of samples to generate unconditionally, '
-                       'defaults to 0 and interactive conditional sampling')
-    group.add_argument("--genfile", type=str,
-                       help='output file when generating unconditionally')
-    group.add_argument("--recompute", action='store_true',
-                       help='during generation recompute all attention '
-                       'instead of using previously computed keys/values.')
-    return parser


def add_data_args_(parser):
    """Train/valid/test data arguments."""

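With add_text_generate_args removed from megatron/arguments.py, the generation script above registers its own arguments through the extra_args_provider hook of initialize_megatron / parse_args (see the call in main()). A simplified sketch of that hook pattern follows, with the common argument groups abbreviated and the defaults handling assumed rather than copied from this diff:

# Simplified sketch of the extra_args_provider hook in parse_args;
# the bodies of the common argument groups are elided.
import argparse

def parse_args(extra_args_provider=None, defaults={}):
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments')
    # ... common groups: training, regularization, data, gpt2, ... ...
    if extra_args_provider is not None:
        # Let the caller (e.g. the text-generation script) append its group.
        parser = extra_args_provider(parser)
    args = parser.parse_args()
    # Apply caller-supplied defaults such as {'tokenizer_type': 'GPT2BPETokenizer'}
    # only where the user did not pass a value explicitly.
    for key, value in defaults.items():
        if getattr(args, key, None) is None:
            setattr(args, key, value)
    return args
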
+1 −2
@@ -137,8 +137,7 @@ class BertModel(MegatronModule):
            self._binary_head_key = 'binary_head'


-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
+    def forward(self, input_ids, attention_mask, tokentype_ids=None):

        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
+6 −2
@@ -51,7 +51,8 @@ class GPT2Model(MegatronModule):


    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
+                tokentype_ids=None, layer_past=None, get_key_value=False,
+                forward_method_parallel_output=None):

        # Language model.
        lm_output = self.language_model(input_ids,
@@ -65,10 +66,13 @@ class GPT2Model(MegatronModule):
            lm_output, presents = lm_output

        # Output.
+        parallel_output = self.parallel_output
+        if forward_method_parallel_output is not None:
+            parallel_output = forward_method_parallel_output
        output = parallel_lm_logits(
            lm_output,
            self.language_model.embedding.word_embeddings.weight,
-            self.parallel_output)
+            parallel_output)

        if get_key_value:
            output = [output, presents]
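
The new forward_method_parallel_output argument lets a single call override the model-wide parallel_output setting: training keeps logits partitioned across tensor-parallel ranks, while sampling wants the full-vocabulary logits available on every rank. Below is a hedged sketch of how a generation step might use the override; the helper name and surrounding variables are assumptions for illustration, not code from this merge request:

# Illustrative sketch only: per-call override of parallel output for sampling.
import torch

def generation_forward(model, tokens, position_ids, attention_mask, layer_past=None):
    with torch.no_grad():
        output, layer_past = model(tokens, position_ids, attention_mask,
                                   layer_past=layer_past,
                                   get_key_value=True,
                                   # Gather logits over the full vocabulary so every
                                   # tensor-parallel rank can sample identically.
                                   forward_method_parallel_output=False)
    # Sample from the last position's logits.
    return output[:, -1, :], layer_past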
+411 −0

File changed and moved. File mode changed from 100755 to 100644.

Preview size limit exceeded, changes collapsed.
