Commit a6ba254f authored by Mohammad's avatar Mohammad
Browse files

generate samples linted

parent a19820b1
Loading
Loading
Loading
Loading

generate_samples.py

100755 → 100644
+72 −82
Original line number Diff line number Diff line
@@ -15,34 +15,27 @@

"""Sample Generate GPT2"""

import os
import random
import json
import copy
import numpy as np
import json
import os
import time

import torch
import torch.nn.functional as F
import argparse
import time
from arguments import get_args
from megatron.utils import Timers
from megatron.utils import initialize_distributed
from megatron.utils import set_random_seed
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import load_checkpoint
from megatron.data_utils import make_tokenizer
from configure_data import configure_data
from megatron import mpu

from megatron.fp16 import FP16_Module
from megatron.model import GPT2Model
from megatron.model import DistributedDataParallel as DDP
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPT2Model
from megatron.training import get_model
from megatron.utils import get_ltor_masks_and_position_ids


def model_provider():
    """Build the model."""
    args = get_args()

    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_tokentypes=0, parallel_output=False)
@@ -56,7 +49,7 @@ def get_batch(context_tokens):
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.batch_size, -1)..contiguous().cuda()
    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
    # Get the attention mask and postition ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
@@ -103,7 +96,7 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):


def generate_samples_input_from_file(model):
    """XXX"""

    args = get_args()
    tokenizer = get_tokenizer()

@@ -118,7 +111,7 @@ def generate_samples_input_from_file(model):
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('could not find `sample-output-file`, setting '
                  'it to {}'.formatsample_output_file())
                  'it to {}'.format(sample_output_file))
        fname_out = open(sample_output_file, "w+")

    context_count = 0
@@ -158,9 +151,8 @@ def generate_samples_input_from_file(model):
            if terminate_runs == 1:
                return

            start_time = time.time()
            token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
            for _, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()

@@ -184,7 +176,7 @@ def generate_samples_input_from_file(model):


def generate_samples_interactive(model, print_frequency=24):
    """XXX"""

    args = get_args()
    tokenizer = get_tokenizer()

@@ -226,7 +218,6 @@ def generate_samples_interactive(model, print_frequency=24):
            if terminate_runs == 1:
                return

            start_time = time.time()
            token_stream = get_token_stream(model, [context_tokens])
            for counter, decode_tokens in enumerate(token_stream):
                decode_tokens, _ = decode_tokens
@@ -256,14 +247,13 @@ def generate_samples_interactive(model, print_frequency=24):


def generate_samples_unconditional(model):
    """XXX"""

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod]
                      for _ in range(args.batch_size)]
    samples = []
    ctr = 0
    while True:
        start_time = time.time()
@@ -291,6 +281,7 @@ def generate_samples_unconditional(model):


def write_and_generate_samples_unconditional(model):

    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
@@ -298,8 +289,8 @@ def write_and_generate_samples_unconditional(model):
            f.write(json.dumps(datum)+'\n')


def pad_batch(batch, tokenizer, args):
    pad_id = tokenizer.eod
def pad_batch(batch, pad_id, args):

    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
@@ -310,11 +301,12 @@ def pad_batch(batch, tokenizer, args):


def get_token_stream(model, context_tokens):

    args = get_args()
    tokenizer = get_tokenizer()

    pad_id = tokenizer.eod
    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer, args)
    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)
@@ -329,11 +321,6 @@ def get_token_stream(model, context_tokens):
    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)

    counter = 0
    org_context_length = context_length

    layer_past = None

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
@@ -343,6 +330,7 @@ def get_token_stream(model, context_tokens):


def switch(val1, val2, boolean):

    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2

@@ -350,14 +338,14 @@ def switch(val1, val2, boolean):
def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    """XXX"""

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.get_command('eos').Id
        eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length
@@ -429,7 +417,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
            done_token = (prev == eos_id).byte() & started.byte()
            just_finished = (done_token & ~is_done).bool()
            lengths[just_finished.view(-1)] = context_length
            was_done = is_done
            is_done = is_done | done_token
            done = torch.all(is_done)

@@ -438,46 +425,51 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                break

def add_text_generate_args(parser):
    """Text generate arguments."""

    group = parser.add_argument_group('Text generation', 'configurations')
    group.add_argument("--temperature", type=float, default=1.0)
    group.add_argument("--greedy", action='store_true', default=False)
    group.add_argument("--top_p", type=float, default=0.0)
    group.add_argument("--top_k", type=int, default=0)
    group.add_argument("--out-seq-length", type=int, default=1024)
    """Text generation arguments."""
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='Use greedy sampling.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    group.add_argument("--sample-input-file", type=str, default=None,
                      help='get input from file instead of interactive mode, '
                           'each line is an input' )
                       help='Get input from file instead of interactive mode, '
                       'each line is an input.')
    group.add_argument("--sample-output-file", type=str, default=None,
                      help='output file got from --sample-input-file')
                       help='Output file got from --sample-input-file')
    group.add_argument("--num-samples", type=int, default=0,
                       help='number of samples to generate unconditionally, '
                       help='Number of samples to generate unconditionally, '
                       'defaults to 0 and interactive conditional sampling')
    group.add_argument("--genfile", type=str,
                       help='output file when generating unconditionally')
                       help='Output file when generating unconditionally')
    group.add_argument("--recompute", action='store_true',
                       help='during generation recompute all attention '
                       help='During generation recompute all attention '
                       'instead of using previously computed keys/values.')

    return parser


def main():
    """Main program."""

    print('Generate Samples')

    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})

    # Set up model and load checkpoint.
    model = get_model(model_provider)
    args = get_args()
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    #generate samples
    # Generate samples.
    if args.num_samples == 0:
        args.batch_size = 1
        assert args.batch_size == 1
        if args.sample_input_file != "":
            generate_samples_input_from_file(model)
        else:
@@ -487,7 +479,5 @@ def main():


if __name__ == "__main__":
    main()



    main()

pretrain_gpt2.py

100755 → 100644
+0 −0

File mode changed from 100755 to 100644.