generate_samples.py  100755 → 100644  (+72 −82)

@@ -15,34 +15,27 @@
 """Sample Generate GPT2"""

-import os
-import random
-import json
 import copy
-import numpy as np
+import json
+import os
+import time
+
 import torch
 import torch.nn.functional as F
-import argparse
-import time

-from arguments import get_args
-from megatron.utils import Timers
-from megatron.utils import initialize_distributed
-from megatron.utils import set_random_seed
-from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import load_checkpoint
-from megatron.data_utils import make_tokenizer
-from configure_data import configure_data
-from megatron import mpu
-from megatron.fp16 import FP16_Module
-from megatron.model import GPT2Model
-from megatron.model import DistributedDataParallel as DDP
+from megatron import get_args
+from megatron import get_tokenizer
+from megatron import mpu
+from megatron import print_rank_0
+from megatron.checkpointing import load_checkpoint
+from megatron.initialize import initialize_megatron
+from megatron.model import GPT2Model
+from megatron.training import get_model
+from megatron.utils import get_ltor_masks_and_position_ids


 def model_provider():
     """Build the model."""

-    args = get_args()
     print_rank_0('building GPT2 model ...')
     model = GPT2Model(num_tokentypes=0, parallel_output=False)

@@ -56,7 +49,7 @@ def get_batch(context_tokens):
     tokenizer = get_tokenizer()

     # Move to GPU.
-    tokens = context_tokens.view(args.batch_size, -1)..contiguous().cuda()
+    tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()

     # Get the attention mask and postition ids.
     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,

@@ -103,7 +96,7 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):

 def generate_samples_input_from_file(model):
-    """XXX"""
+
     args = get_args()
     tokenizer = get_tokenizer()

@@ -118,7 +111,7 @@ def generate_samples_input_from_file(model):
     if args.sample_output_file is None:
         sample_output_file = args.sample_input_file + ".out"
         print('could not find `sample-output-file`, setting '
-              'it to {}'.formatsample_output_file())
+              'it to {}'.format(sample_output_file))
     fname_out = open(sample_output_file, "w+")

     context_count = 0

@@ -158,9 +151,8 @@ def generate_samples_input_from_file(model):
             if terminate_runs == 1:
                 return

-            start_time = time.time()
             token_stream = get_token_stream(model, [context_tokens])
-            for counter, decode_tokens in enumerate(token_stream):
+            for _, decode_tokens in enumerate(token_stream):
                 decode_tokens, _ = decode_tokens
                 decode_tokens = decode_tokens[0].cpu().numpy().tolist()

@@ -184,7 +176,7 @@ def generate_samples_input_from_file(model):

 def generate_samples_interactive(model, print_frequency=24):
-    """XXX"""
+
     args = get_args()
     tokenizer = get_tokenizer()

@@ -226,7 +218,6 @@ def generate_samples_interactive(model, print_frequency=24):
             if terminate_runs == 1:
                 return

-            start_time = time.time()
             token_stream = get_token_stream(model, [context_tokens])
             for counter, decode_tokens in enumerate(token_stream):
                 decode_tokens, _ = decode_tokens

@@ -256,14 +247,13 @@ def generate_samples_interactive(model, print_frequency=24):

 def generate_samples_unconditional(model):
-    """XXX"""
+
     args = get_args()
     tokenizer = get_tokenizer()

     num_samples = args.num_samples
     context_tokens = [[tokenizer.eod]
                       for _ in range(args.batch_size)]
-    samples = []
     ctr = 0
     while True:
         start_time = time.time()
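The body of top_k_logits (only its signature is visible, in the @@ -103 hunk header above) is not part of this diff. For context, here is a minimal sketch of the standard top-k / nucleus filtering such a function implements; the name top_k_top_p_filter is hypothetical and the actual file may differ in detail:

import torch
import torch.nn.functional as F

def top_k_top_p_filter(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Standard top-k / nucleus filtering over the last dimension."""
    if top_k > 0:
        # Mask out everything below the k-th largest logit.
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth] = filter_value
    if top_p > 0.0:
        # Keep the smallest prefix of sorted tokens whose cumulative
        # probability exceeds top_p; drop the rest.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # always keep the top token
        remove[..., 0] = False
        # Scatter the mask back from sorted order to vocabulary order.
        logits[remove.scatter(-1, sorted_indices, remove)] = filter_value
    return logits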
@@ -291,6 +281,7 @@ def generate_samples_unconditional(model):

 def write_and_generate_samples_unconditional(model):
+    args = get_args()
     assert args.genfile is not None
     with open(args.genfile, 'w') as f:

@@ -298,8 +289,8 @@ def write_and_generate_samples_unconditional(model):
             f.write(json.dumps(datum)+'\n')


-def pad_batch(batch, tokenizer, args):
-    pad_id = tokenizer.eod
+def pad_batch(batch, pad_id, args):
+
     context_lengths = []
     for tokens in batch:
         context_length = len(tokens)

@@ -310,11 +301,12 @@ def pad_batch(batch, tokenizer, args):

 def get_token_stream(model, context_tokens):
     args = get_args()
     tokenizer = get_tokenizer()
+    pad_id = tokenizer.eod

-    context_tokens, context_lengths = pad_batch(context_tokens,
-                                                tokenizer, args)
+    context_tokens, context_lengths = pad_batch(context_tokens,
+                                                tokenizer.eod, args)

     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
     context_length_tensor = torch.cuda.LongTensor(context_lengths)

@@ -329,11 +321,6 @@ def get_token_stream(model, context_tokens):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor,
                                                      args)

-    counter = 0
-    org_context_length = context_length
-    layer_past = None

     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids)

@@ -343,6 +330,7 @@ def get_token_stream(model, context_tokens):


 def switch(val1, val2, boolean):
+
     boolean = boolean.type_as(val1)
     return (1 - boolean) * val1 + boolean * val2

@@ -350,14 +338,14 @@ def switch(val1, val2, boolean):

 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
                           maxlen=None, type_ids=None):
-    """XXX"""
+
     args = get_args()
     tokenizer = get_tokenizer()

     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
-        eos_id = tokenizer.get_command('eos').Id
+        eos_id = tokenizer.eod

         counter = 0
         org_context_length = context_length

@@ -429,7 +417,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
             done_token = (prev == eos_id).byte() & started.byte()
             just_finished = (done_token & ~is_done).bool()
             lengths[just_finished.view(-1)] = context_length
-            was_done = is_done
             is_done = is_done | done_token
             done = torch.all(is_done)
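The switch helper above is small enough to read in full: it blends two tensors elementwise, taking val2 wherever boolean is true and val1 elsewhere. In sample_sequence_batch this lets sequences still inside their prompt keep their context tokens while started sequences receive freshly sampled ones. A quick self-contained illustration (the example tensors are made up):

import torch

def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2

context = torch.tensor([11, 12, 13, 14])  # existing context tokens
sampled = torch.tensor([21, 22, 23, 24])  # freshly sampled tokens
started = torch.tensor([0, 0, 1, 1])      # which positions are past the prompt
print(switch(context, sampled, started))  # tensor([11, 12, 23, 24])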
@@ -438,46 +425,51 @@
                 break


 def add_text_generate_args(parser):
-    """Text generate arguments."""
-
-    group = parser.add_argument_group('Text generation', 'configurations')
-    group.add_argument("--temperature", type=float, default=1.0)
-    group.add_argument("--greedy", action='store_true', default=False)
-    group.add_argument("--top_p", type=float, default=0.0)
-    group.add_argument("--top_k", type=int, default=0)
-    group.add_argument("--out-seq-length", type=int, default=1024)
+    """Text generation arguments."""
+
+    group = parser.add_argument_group(title='text generation')
+
+    group.add_argument("--temperature", type=float, default=1.0,
+                       help='Sampling temperature.')
+    group.add_argument("--greedy", action='store_true', default=False,
+                       help='Use greedy sampling.')
+    group.add_argument("--top_p", type=float, default=0.0,
+                       help='Top p sampling.')
+    group.add_argument("--top_k", type=int, default=0,
+                       help='Top k sampling.')
+    group.add_argument("--out-seq-length", type=int, default=1024,
+                       help='Size of the output generated text.')
     group.add_argument("--sample-input-file", type=str, default=None,
-                       help='get input from file instead of interactive mode, '
-                       'each line is an input'
-                       )
+                       help='Get input from file instead of interactive mode, '
+                       'each line is an input.')
     group.add_argument("--sample-output-file", type=str, default=None,
-                       help='output file got from --sample-input-file')
+                       help='Output file got from --sample-input-file')
     group.add_argument("--num-samples", type=int, default=0,
-                       help='number of samples to generate unconditionally, '
+                       help='Number of samples to generate unconditionally, '
                        'defaults to 0 and interactive conditional sampling')
     group.add_argument("--genfile", type=str,
-                       help='output file when generating unconditionally')
+                       help='Output file when generating unconditionally')
     group.add_argument("--recompute", action='store_true',
-                       help='during generation recompute all attention '
+                       help='During generation recompute all attention '
                        'instead of using previously computed keys/values.')

     return parser


 def main():
     """Main program."""

-    print('Generate Samples')
     initialize_megatron(extra_args_provider=add_text_generate_args,
                         args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})

     # Set up model and load checkpoint.
     model = get_model(model_provider)
     args = get_args()
     if args.load is not None:
         _ = load_checkpoint(model, None, None)

-    #generate samples
+    # Generate samples.
     if args.num_samples == 0:
         args.batch_size = 1
-        assert args.batch_size == 1
         if args.sample_input_file != "":
             generate_samples_input_from_file(model)
         else:

@@ -487,7 +479,5 @@ def main():

 if __name__ == "__main__":

-    main()
+    main()

pretrain_gpt2.py  100755 → 100644  (+0 −0)
File mode changed from 100755 to 100644.
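Since add_text_generate_args now returns the parser, the renamed flags and their defaults in the generate_samples.py hunk above can be checked in isolation. A small sketch, assuming generate_samples.py is importable as a module; in normal use initialize_megatron() builds the full Megatron parser and calls this hook itself:

import argparse

# Hypothetical standalone check, not how Megatron invokes this function.
from generate_samples import add_text_generate_args

parser = add_text_generate_args(argparse.ArgumentParser())
args = parser.parse_args(['--top_k', '40', '--temperature', '0.9'])
print(args.top_k, args.temperature, args.out_seq_length)  # -> 40 0.9 1024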