Loading megatron/text_generation_utils.py +8 −7 Original line number Diff line number Diff line Loading @@ -85,21 +85,22 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): return logits def pad_batch(batch, pad_id, args): def pad_batch(batch, pad_id, max_len): context_lengths = [] max_context_length = max([len(tokens) for tokens in batch]) for tokens in batch: context_length = len(tokens) if context_length < args.seq_length: tokens.extend([pad_id] * (args.seq_length - context_length)) if context_length < max_context_length + max_len: tokens.extend([pad_id] * (max_context_length + max_len - context_length)) context_lengths.append(context_length) return batch, context_lengths def tokenize_batch(sentences): def tokenize_batch(sentences, max_len): args = get_args() tokenizer = get_tokenizer() context_tokens = [tokenizer.tokenize(s) for s in sentences] context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, args) tokenizer.eod, max_len) context_tokens_tensor = torch.cuda.LongTensor(context_tokens) context_length_tensor = torch.cuda.LongTensor(context_lengths) return context_tokens_tensor, context_length_tensor Loading Loading @@ -178,12 +179,13 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False): model.eval() if torch.distributed.get_rank() == 0: context_tokens_tensor, context_length_tensor = tokenize_batch(sentences) context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate) send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs) else: context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info() output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs) if output is not None: decode_tokens, output_logits, full_logits = output Loading Loading @@ -290,7 +292,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths, # Generate enough tokens for the longest sequence maxlen = tokens_to_generate + context_lengths.max().item() # TODO(rprenger) Need a better understanding of what args.seq_length vs args.out_seq_length (shouldn't be "args") if maxlen > args.seq_length: maxlen = args.seq_length Loading Loading
megatron/text_generation_utils.py +8 −7 Original line number Diff line number Diff line Loading @@ -85,21 +85,22 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): return logits def pad_batch(batch, pad_id, args): def pad_batch(batch, pad_id, max_len): context_lengths = [] max_context_length = max([len(tokens) for tokens in batch]) for tokens in batch: context_length = len(tokens) if context_length < args.seq_length: tokens.extend([pad_id] * (args.seq_length - context_length)) if context_length < max_context_length + max_len: tokens.extend([pad_id] * (max_context_length + max_len - context_length)) context_lengths.append(context_length) return batch, context_lengths def tokenize_batch(sentences): def tokenize_batch(sentences, max_len): args = get_args() tokenizer = get_tokenizer() context_tokens = [tokenizer.tokenize(s) for s in sentences] context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, args) tokenizer.eod, max_len) context_tokens_tensor = torch.cuda.LongTensor(context_tokens) context_length_tensor = torch.cuda.LongTensor(context_lengths) return context_tokens_tensor, context_length_tensor Loading Loading @@ -178,12 +179,13 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False): model.eval() if torch.distributed.get_rank() == 0: context_tokens_tensor, context_length_tensor = tokenize_batch(sentences) context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate) send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs) else: context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info() output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs) if output is not None: decode_tokens, output_logits, full_logits = output Loading Loading @@ -290,7 +292,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths, # Generate enough tokens for the longest sequence maxlen = tokens_to_generate + context_lengths.max().item() # TODO(rprenger) Need a better understanding of what args.seq_length vs args.out_seq_length (shouldn't be "args") if maxlen > args.seq_length: maxlen = args.seq_length Loading