megatron/api_server.py  +5 −3

@@ -41,9 +41,11 @@ class MegatronGenerate(Resource):
         max_len = 64  # Choosing hopefully sane default.  Full sequence is slow
         if "max_len" in request.get_json():
-            input_max_len = request.get_json()["max_len"]
-            if input_max_len < args.seq_length:
-                max_len = input_max_len
+            max_len = request.get_json()["max_len"]
+            if not isinstance(max_len, int):
+                return "max_len must be an integer greater than 0"
+            if max_len < 1:
+                return "max_len must be an integer greater than 0"
 
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
         resp_sentences = generate(self.model, sentences, max_len)
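For context, a client-side sketch of how the new validation behaves is below. The route, port, HTTP verb, and the "sentences" payload key are assumptions for illustration only; the diff above only confirms the "max_len" handling.

    import requests

    # Hypothetical endpoint; route, port, and HTTP verb are assumptions,
    # only the max_len semantics come from the change above.
    URL = "http://localhost:5000/api"

    # Accepted: a positive integer caps generation at 64 new tokens.
    requests.put(URL, json={"sentences": ["Megatron-LM is"], "max_len": 64})

    # Rejected: the server now returns the string
    # "max_len must be an integer greater than 0".
    requests.put(URL, json={"sentences": ["Megatron-LM is"], "max_len": "64"})
    requests.put(URL, json={"sentences": ["Megatron-LM is"], "max_len": 0})

Note that the old upper-bound check against args.seq_length is gone; an oversized max_len is now handled downstream by the out_seq_length clamp in sample_sequence_batch (see the second file below).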
megatron/text_generation_utils.py  +20 −27

@@ -104,21 +104,6 @@ def tokenize_batch(sentences):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor
 
-def get_token_stream(model, context_tokens_tensor, context_length_tensor):
-    context_length = context_length_tensor.min().item()
-    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
-
-    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
-                                                 context_length_tensor,
-                                                 attention_mask, position_ids)
-    for tokens, lengths in batch_token_iterator:
-        context_length += 1
-        if tokens is not None:
-            yield tokens[:, :context_length], lengths
-        else:
-            yield None, None
-
 def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
     """
     Needs to be synced up with receive_generate_info

@@ -151,13 +136,19 @@ def receive_generate_info():
     return context_length_tensor, context_tokens_tensor, max_len
 
-def synced_generate(model, context_length_tensor, context_tokens_tensor, max_len):
-    token_stream = get_token_stream(model, context_tokens_tensor, context_length_tensor)
-    for i, decode_tokens in enumerate(token_stream):
-        if i == max_len-1:
-            break
-        pass
-    return decode_tokens
+def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len):
+    context_length = context_length_tensor.min().item()
+    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
+
+    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
+                                                 context_length_tensor,
+                                                 attention_mask, position_ids,
+                                                 max_len)
+    for tokens, lengths in batch_token_iterator:
+        context_length += 1
+
+    if tokens is not None:
+        return tokens[:, :context_length]
 
 def generate(model, sentences=None, max_len=0):
     if torch.distributed.get_rank() == 0:

@@ -169,12 +160,11 @@ def generate(model, sentences=None, max_len=0):
     else:
         context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
 
-    decode_tokens = synced_generate(model, context_length_tensor, context_tokens_tensor, max_len)
+    decode_tokens = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
 
     if torch.distributed.get_rank() == 0:
         args = get_args()
         tokenizer = get_tokenizer()
-        decode_tokens, _ = decode_tokens
         resp_sentences = []
         for i in range(decode_tokens.size(0)):
             decode_token = decode_tokens[i,:].cpu().numpy().tolist()

@@ -248,6 +238,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         tokens = context_tokens
         if maxlen is None:
             maxlen = args.seq_length - 1
+
+        maxlen = maxlen + org_context_length
+
         if maxlen > (org_context_length + args.out_seq_length):
             maxlen = org_context_length + args.out_seq_length
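Taken together, the requested max_len now travels from the API handler through generate() and synced_generate() into sample_sequence_batch, where it is shifted by the prompt length and capped by args.out_seq_length. A standalone sketch of that clamping arithmetic is below; the helper name and signature are hypothetical (the real logic sits inline in sample_sequence_batch), and it assumes the shift by org_context_length applies whenever a caller supplies max_len.

    def resolve_maxlen(maxlen, org_context_length, seq_length, out_seq_length):
        # Hypothetical helper mirroring the clamping added above.
        # maxlen: number of new tokens requested by the caller, or None.
        # org_context_length: shortest prompt length in the batch.
        if maxlen is None:
            # No request-side limit: fall back to the model's sequence length.
            maxlen = seq_length - 1
        # Shift the per-request budget to an absolute position (prompt + new tokens).
        maxlen = maxlen + org_context_length
        # Never run more than out_seq_length tokens past the prompt.
        if maxlen > (org_context_length + out_seq_length):
            maxlen = org_context_length + out_seq_length
        return maxlen

    # A 10-token prompt with max_len=64 and out_seq_length=256 stops at
    # absolute position 74; max_len=1024 is clamped to 266.
    assert resolve_maxlen(64, 10, seq_length=2048, out_seq_length=256) == 74
    assert resolve_maxlen(1024, 10, seq_length=2048, out_seq_length=256) == 266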