megatron/api_server.py  (+13 −1)

@@ -47,8 +47,20 @@ class MegatronGenerate(Resource):
         if max_len < 1:
             return "max_len must be an integer greater than 0"
 
+        all_probs = False
+        if "all_probs" in request.get_json():
+            all_probs = request.get_json()["all_probs"]
+            if not isinstance(all_probs, bool):
+                return "all_probs must be a boolean value"
+
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits = generate(self.model, sentences, max_len)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits = generate(self.model, sentences, max_len, all_probs)
+        if all_probs:
+            return jsonify({"sentences": resp_sentences,
+                            "segments": resp_sentences_seg,
+                            "logits": output_logits,
+                            "all_logits": full_logits})
+
         return jsonify({"sentences": resp_sentences,
                         "segments": resp_sentences_seg,
                         "logits": output_logits})
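The api_server.py change adds an optional boolean "all_probs" field to the request: when it is set, the response carries an extra "all_logits" entry with the full log-probability distributions alongside the existing per-token "logits". A minimal client sketch follows; the URL, HTTP method, and the "sentences"/"max_len" request fields are assumptions for illustration and are not shown in this diff, only the "all_probs" flag and the response keys come from the change.

# Hypothetical client call; the endpoint URL, HTTP verb, and the
# "sentences"/"max_len" request keys are assumptions, not taken from this diff.
import requests

resp = requests.post("http://localhost:5000/generate",
                     json={"sentences": ["Megatron-LM is"],
                           "max_len": 32,
                           "all_probs": True})   # new flag added by this change
data = resp.json()
print(data["sentences"])    # generated text
print(data["logits"])       # log-probs of the sampled tokens
print(data["all_logits"])   # full per-step distributions (only when all_probs is set)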
megatron/text_generation_utils.py  (+43 −19)

@@ -104,12 +104,12 @@ def tokenize_batch(sentences):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor
 
-def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
+def send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len]
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len, all_probs]
     input_info_tensor = torch.cuda.LongTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)

@@ -126,6 +126,7 @@ def receive_generate_info():
     batch_size = input_info_tensor[0].item()
     seq_len = input_info_tensor[1].item()
     max_len = input_info_tensor[2].item()
+    all_probs = input_info_tensor[3].item()
 
     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.device("cuda"))
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.device("cuda"))

@@ -134,23 +135,29 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)
-    return context_length_tensor, context_tokens_tensor, max_len
+    return context_length_tensor, context_tokens_tensor, max_len, all_probs
 
-def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
 
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor, context_length_tensor,
                                                  attention_mask, position_ids,
-                                                 max_len)
-    for tokens, lengths, output_logits in batch_token_iterator:
+                                                 max_len, all_probs)
+    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
         context_length += 1
 
     if mpu.is_pipeline_last_stage():
         src = mpu.get_pipeline_model_parallel_last_rank()
         group = mpu.get_embedding_group()
         torch.distributed.broadcast(output_logits, src, group)
+        if all_probs:
+            src = mpu.get_pipeline_model_parallel_last_rank()
+            group = mpu.get_embedding_group()
+            torch.distributed.broadcast(full_logits, src, group)
 
     else:
         if mpu.is_pipeline_first_stage():
             src = mpu.get_pipeline_model_parallel_last_rank()

@@ -158,22 +165,28 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
             output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
             torch.distributed.broadcast(output_logits, src, group)
+            if all_probs:
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_embedding_group()
+                full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size(), dtype=torch.float32, device=torch.device("cuda"))
+                torch.distributed.broadcast(full_logits, src, group)
 
     if tokens is not None:
-        return tokens[:, :context_length], output_logits
+        return tokens[:, :context_length], output_logits, full_logits
 
-def generate(model, sentences=None, max_len=0):
+def generate(model, sentences=None, max_len=0, all_probs=False):
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
         c = context_length_tensor[0]
         b = context_tokens_tensor.size(0)
         start = time.time()
-        send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
+        send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
     else:
-        context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
+        context_length_tensor, context_tokens_tensor, max_len, all_probs = receive_generate_info()
 
-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs)
     if output is not None:
-        decode_tokens, output_logits = output
+        decode_tokens, output_logits, full_logits = output
 
         if torch.distributed.get_rank() == 0:
             args = get_args()

@@ -191,9 +204,12 @@ def generate(model, sentences=None, max_len=0):
                 resp_sentences_seg.append(words)
 
             output_logits = output_logits.cpu().numpy().tolist()
+            if all_probs:
+                full_logits = full_logits.cpu().numpy().tolist()
+
            end = time.time()
            print(str(b)+","+str(c)+","+str(decode_tokens.size(1))+","+str(end-start), flush=True)
-            return resp_sentences, resp_sentences_seg, output_logits
+            return resp_sentences, resp_sentences_seg, output_logits, full_logits
 
 def switch(val1, val2, boolean):
     boolean = boolean.type_as(val1)
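The send_generate_info / receive_generate_info pair packs every per-request scalar, now including all_probs, into a single CUDA LongTensor and broadcasts it from rank 0, so the boolean travels as a 0/1 integer and comes back out of .item() as an int rather than a bool (still truthy in the later "if all_probs:" checks). A standalone sketch of that handshake, with an invented function name and the assumption that the process group is already initialized:

# Sketch only: mirrors the pack-into-LongTensor-and-broadcast pattern above.
import torch
import torch.distributed as dist

def broadcast_generate_info(batch_size=0, seq_len=0, max_len=0, all_probs=False):
    if dist.get_rank() == 0:
        # bools are stored as 0/1 so they fit in the integer tensor
        info = torch.cuda.LongTensor([batch_size, seq_len, max_len, int(all_probs)])
    else:
        info = torch.empty(4, dtype=torch.int64, device=torch.device("cuda"))
    dist.broadcast(info, 0)
    return info[0].item(), info[1].item(), info[2].item(), bool(info[3].item())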
@@ -236,7 +252,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
-                          maxlen=None, type_ids=None):
+                          maxlen=None, all_probs=False, type_ids=None):
 
     args = get_args()
     tokenizer = get_tokenizer()

@@ -318,12 +334,17 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     output_context = F.log_softmax(output[:, :context_length, :], 2)
                     indices = torch.unsqueeze(tokens[:, 1:context_length+1], 2)
                     output_logits = torch.gather(output_context, 2, indices).squeeze(2)
+                    if all_probs:
+                        full_logits = output_context
                 else:
+                    output_context = F.log_softmax(output, 2)
                     indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
-                    new_output_logits = torch.gather(F.log_softmax(output, 2), 2, indices).squeeze(2)
+                    new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
 
                     # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                     output_logits = torch.cat([output_logits, new_output_logits], 1)
+                    if all_probs:
+                        full_logits = torch.cat([full_logits, output_context], 1)
                     #output_logits = torch.cat([output_logits, output[:,context_length,new_tokens]], 1)
 
                 src = mpu.get_pipeline_model_parallel_last_rank()

@@ -339,7 +360,10 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_pipeline_model_parallel_group()
                 torch.distributed.broadcast(done, src, group)
-                yield tokens, lengths, output_logits
+                if all_probs:
+                    yield tokens, lengths, output_logits, full_logits
+                else:
+                    yield tokens, lengths, output_logits, None
 
             else:
                 if mpu.is_pipeline_first_stage():

@@ -348,9 +372,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     new_tokens = torch.empty_like(tokens[:, context_length])
                     torch.distributed.broadcast(new_tokens, src, group)
                     tokens[:, context_length] = new_tokens
-                    yield tokens, None, None
+                    yield tokens, None, None, None
                 else:
-                    yield None, None, None
+                    yield None, None, None, None
 
             done = torch.cuda.ByteTensor([0])
             src = mpu.get_pipeline_model_parallel_last_rank()
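Inside sample_sequence_batch the two returned quantities differ only in what is kept after the log_softmax: output_logits gathers the log-probability of each chosen token, while full_logits (only when all_probs is set) keeps the whole distribution and concatenates it along the sequence dimension. A toy illustration of that relationship, with synthetic shapes that are made up for this sketch:

# Toy shapes only; "chosen" plays the role of output_logits and
# "log_probs" the role of full_logits in the diff above.
import torch
import torch.nn.functional as F

batch, steps, vocab = 2, 5, 11
logits = torch.randn(batch, steps, vocab)
tokens = torch.randint(vocab, (batch, steps))

log_probs = F.log_softmax(logits, dim=2)              # full distribution, [batch, steps, vocab]
chosen = torch.gather(log_probs, 2,
                      tokens.unsqueeze(2)).squeeze(2)  # per-token log-probs, [batch, steps]

assert log_probs.shape == (batch, steps, vocab)
assert chosen.shape == (batch, steps)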
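One practical consequence of requesting all_probs that the diff itself does not spell out: full_logits is a float32 tensor of shape [batch, sequence length, padded vocab size], so it can dominate GPU memory and response size. A back-of-the-envelope estimate, assuming a GPT-style padded vocabulary of roughly 50k (the vocabulary size is an assumption, not taken from this diff):

# Assumed sizes for illustration only.
batch, seq_len, vocab = 4, 1024, 50304
bytes_per_fp32 = 4
size_gib = batch * seq_len * vocab * bytes_per_fp32 / 2**30
print(f"full_logits would take about {size_gib:.2f} GiB")  # ~0.77 GiB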