megatron/api_server.py  +4 −3

@@ -48,9 +48,10 @@ class MegatronGenerate(Resource):
             return "max_len must be an integer greater than 0"
 
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences = generate(self.model, sentences, max_len)
-
-        return jsonify({"sentences": resp_sentences})
+        resp_sentences, resp_sentences_seg, output_logits = generate(self.model, sentences, max_len)
+        return jsonify({"sentences": resp_sentences,
+                        "segments": resp_sentences_seg,
+                        "logits": output_logits})
 
 def index():
     return current_app.send_static_file('index.html')

megatron/text_generation_utils.py  +47 −9

@@ -144,11 +144,22 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
                                                  max_len)
-    for tokens, lengths in batch_token_iterator:
+    for tokens, lengths, output_logits in batch_token_iterator:
         context_length += 1
 
+    if mpu.is_pipeline_last_stage():
+        src = mpu.get_pipeline_model_parallel_last_rank()
+        group = mpu.get_embedding_group()
+        torch.distributed.broadcast(output_logits, src, group)
+    else:
+        if mpu.is_pipeline_first_stage():
+            src = mpu.get_pipeline_model_parallel_last_rank()
+            group = mpu.get_embedding_group()
+            output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
+            torch.distributed.broadcast(output_logits, src, group)
+
     if tokens is not None:
-        return tokens[:, :context_length]
+        return tokens[:, :context_length], output_logits

@@ -160,18 +171,29 @@ def generate(model, sentences=None, max_len=0):
     else:
         context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
 
-    decode_tokens = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
+    if output is not None:
+        decode_tokens, output_logits = output
 
     if torch.distributed.get_rank() == 0:
         args = get_args()
         tokenizer = get_tokenizer()
         resp_sentences = []
+        resp_sentences_seg = []
         for i in range(decode_tokens.size(0)):
             decode_token = decode_tokens[i,:].cpu().numpy().tolist()
             resp_sentences.append(tokenizer.detokenize(decode_token))
+            words = []
+            for token in decode_token:
+                word = tokenizer.tokenizer.decoder[token]
+                word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode('utf-8', errors='replace')
+                words.append(word)
+            resp_sentences_seg.append(words)
+        output_logits = output_logits.cpu().numpy().tolist()
         end = time.time()
         print(str(b)+","+str(c)+","+str(decode_tokens.size(1))+","+str(end-start), flush=True)
-        return resp_sentences
+        return resp_sentences, resp_sentences_seg, output_logits

 def switch(val1, val2, boolean):
     boolean = boolean.type_as(val1)

@@ -236,6 +258,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         batch_size = context_tokens.size(0)
         is_done = torch.zeros([batch_size]).byte().cuda()
         tokens = context_tokens
+        output_logits = None
+
         if maxlen is None:
             maxlen = args.seq_length - 1

@@ -261,6 +285,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     if type_ids is not None:
                         types2use = type_ids[:, context_length - 1].view(
                             batch_size, -1)
+
                 output, layer_past = forward_step(model, tokens2use,
                                                   positions2use,
                                                   attention_mask,

@@ -288,6 +313,19 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 new_tokens = switch(
                     tokens[:, context_length].view(-1), prev, started)
                 tokens[:, context_length] = new_tokens
+                if output_logits is None:
+                    output_context = F.log_softmax(output[:, :context_length, :], 2)
+                    indices = torch.unsqueeze(tokens[:, :context_length], 2)
+                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
+                else:
+                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
+                    new_output_logits = torch.gather(F.log_softmax(output, 2), 2, indices).squeeze(2)
+
+                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
+                    output_logits = torch.cat([output_logits, new_output_logits], 1)
+                    #output_logits = torch.cat([output_logits, output[:,context_length,new_tokens]], 1)
+
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_embedding_group()
                 torch.distributed.broadcast(new_tokens, src, group)

@@ -301,7 +339,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_pipeline_model_parallel_group()
                 torch.distributed.broadcast(done, src, group)
-                yield tokens, lengths
+                yield tokens, lengths, output_logits
 
             else:
                 if mpu.is_pipeline_first_stage():

@@ -310,9 +348,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     new_tokens = torch.empty_like(tokens[:, context_length])
                     torch.distributed.broadcast(new_tokens, src, group)
                     tokens[:, context_length] = new_tokens
-                    yield tokens, None
+                    yield tokens, None, None
                 else:
-                    yield None, None
+                    yield None, None, None
 
                 done = torch.cuda.ByteTensor([0])
                 src = mpu.get_pipeline_model_parallel_last_rank()
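With this change the generate endpoint returns three parallel fields per request: the detokenized sentences, the per-token string segments, and the per-token log-probabilities (returned under the key "logits"). A minimal client sketch follows; the host, port, route, and HTTP verb are assumptions, since the diff only shows the handler body, while the sentences/max_len inputs and the sentences/segments/logits response keys come from the code above.

import requests

# Hypothetical endpoint; host, port, route, and verb are not shown in the diff.
resp = requests.put(
    "http://localhost:5000/generate",
    json={"sentences": ["The quick brown fox"], "max_len": 16},
)
data = resp.json()

# Three parallel lists per input sentence: the detokenized text, the per-token
# string segments, and the per-token log-probabilities (the key is "logits").
for sent, segs, logps in zip(data["sentences"], data["segments"], data["logits"]):
    print(sent)
    for piece, logp in zip(segs, logps):
        print(f"  {piece!r}: {logp:.3f}")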
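Inside sample_sequence_batch, output_logits is built by log-softmaxing the model output over the vocabulary and gathering the entry of each selected token, then concatenating one column per generated token (so, despite the name, the values returned to the client are log-probabilities rather than raw logits). A self-contained sketch of that gather pattern, with illustrative shapes and tensor names:

import torch
import torch.nn.functional as F

batch_size, vocab_size = 2, 10
logits = torch.randn(batch_size, vocab_size)            # model output, one step
new_tokens = torch.randint(vocab_size, (batch_size,))   # tokens chosen this step

log_probs = F.log_softmax(logits, dim=-1)                # [batch, vocab]
step_logps = torch.gather(
    log_probs, 1, new_tokens.unsqueeze(1)).squeeze(1)    # [batch]

# The loop above appends one column per step via torch.cat (hence the TODO
# about pre-allocating); a running buffer looks like this:
running = torch.empty(batch_size, 0)
running = torch.cat([running, step_logps.unsqueeze(1)], dim=1)  # [batch, n_steps]
print(step_logps, running.shape)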
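Because only the last pipeline stage computes output_logits, synced_generate broadcasts it over the embedding group once generation finishes, and the first stage allocates an empty float32 buffer of shape (batch, context_length - 1) to receive it in place. The sketch below mimics that send/receive contract in a single process; the "gloo" backend with world_size 1 is used purely to make the call runnable (the broadcast is then a no-op), whereas the real src rank and group come from the mpu helpers.

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

batch_size, gen_len = 2, 5
# On the sending rank the tensor already holds values; a receiving rank would
# first allocate an empty buffer of the same shape and dtype, and broadcast
# would fill it in place.
output_logits = torch.randn(batch_size, gen_len, dtype=torch.float32)
dist.broadcast(output_logits, src=0)   # no-op with a single rank
print(output_logits.shape)

dist.destroy_process_group()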