Commit e0bf5199 authored by rprenger

Outputting log probabilities
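
The generate endpoint now returns, alongside each completion, the per-token segmentation of the output and the log probability of every token, gathered from the log-softmax of the model's output at each step.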

parent 279d8320
+4 −3
@@ -48,9 +48,10 @@ class MegatronGenerate(Resource):
                return "max_len must be an integer greater than 0"

        MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences = generate(self.model, sentences, max_len)
-        return jsonify({"sentences": resp_sentences})
-
+        resp_sentences, resp_sentences_seg, output_logits = generate(self.model, sentences, max_len)
+        return jsonify({"sentences": resp_sentences,
+            "segments": resp_sentences_seg,
+            "logits": output_logits})

def index():
    return current_app.send_static_file('index.html')
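
With this change the server response carries two extra fields. A minimal client sketch (the host, port, route, and HTTP verb below are illustrative assumptions, not shown in this diff):

    import requests

    # Hypothetical endpoint registration; adjust to however MegatronGenerate is mounted.
    resp = requests.put("http://localhost:5000/generate",
                        json={"sentences": ["The quick brown fox"], "max_len": 16})
    data = resp.json()
    print(data["sentences"])  # detokenized completions
    print(data["segments"])   # per-token surface strings for each completion
    print(data["logits"])     # per-token log probabilities for each completion

Note that despite the field name, "logits" holds log probabilities: the values are gathered from F.log_softmax further down in this commit, not from raw logits.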
+47 −9
@@ -144,11 +144,22 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
                                                 context_length_tensor,
                                                 attention_mask, position_ids,
                                                 max_len)
-    for tokens, lengths in batch_token_iterator:
+    for tokens, lengths, output_logits in batch_token_iterator:
        context_length += 1
                
+    if mpu.is_pipeline_last_stage():
+        src = mpu.get_pipeline_model_parallel_last_rank()
+        group = mpu.get_embedding_group()
+        torch.distributed.broadcast(output_logits, src, group)
+    else:
+        if mpu.is_pipeline_first_stage():
+            src = mpu.get_pipeline_model_parallel_last_rank()
+            group = mpu.get_embedding_group()
+            output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
+            torch.distributed.broadcast(output_logits, src, group)
+
    if tokens is not None:
-        return tokens[:, :context_length]
+        return tokens[:, :context_length], output_logits

def generate(model, sentences=None, max_len=0):
    if torch.distributed.get_rank() == 0:
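
The new block broadcasts the gathered log probabilities from the last pipeline stage to the first over the embedding group; the receiving rank must first allocate an empty buffer of the agreed shape (batch, context_length - 1) and dtype. A minimal sketch of that pattern outside Megatron (the process-group setup, producer rank, and shapes are assumptions for illustration; Megatron uses CUDA tensors and its own groups):

    import torch
    import torch.distributed as dist

    def exchange_output_logits(is_producer, src_rank, group, batch_size, seq_len):
        if is_producer:
            # Producer already holds the computed values (random stand-ins here).
            output_logits = torch.randn(batch_size, seq_len, dtype=torch.float32)
        else:
            # Consumer allocates a buffer of matching shape/dtype before receiving.
            output_logits = torch.empty(batch_size, seq_len, dtype=torch.float32)
        dist.broadcast(output_logits, src_rank, group=group)  # fills the buffer in place
        return output_logits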
@@ -160,18 +171,29 @@ def generate(model, sentences=None, max_len=0):
    else:
        context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
    
-    decode_tokens = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
+    if output is not None:
+        decode_tokens, output_logits = output

    if torch.distributed.get_rank() == 0:
        args = get_args()
        tokenizer = get_tokenizer()
        resp_sentences = []
+        resp_sentences_seg = []
        for i in range(decode_tokens.size(0)):
            decode_token = decode_tokens[i,:].cpu().numpy().tolist()
            resp_sentences.append(tokenizer.detokenize(decode_token))
+            words = []
+            for token in decode_token:
+                word = tokenizer.tokenizer.decoder[token]
+                word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode('utf-8', errors='replace')
+                words.append(word)
+            resp_sentences_seg.append(words)
+
+        output_logits = output_logits.cpu().numpy().tolist()
        end = time.time()
        print(str(b)+","+str(c)+","+str(decode_tokens.size(1))+","+str(end-start), flush=True)
-        return resp_sentences
+        return resp_sentences, resp_sentences_seg, output_logits

def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
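
The words loop builds the new "segments" field by reversing the byte-level BPE encoding one token at a time: tokenizer.tokenizer.decoder maps a token id to its unicode-escaped string, and byte_decoder maps each character of that string back to a raw byte. A toy illustration with a fabricated two-entry vocabulary (real GPT-2 tokenizers build these tables internally):

    # '\u0120' is the byte-level BPE stand-in for a leading space.
    decoder = {0: 'Hello', 1: '\u0120world'}
    byte_decoder = {c: ord(c) for c in 'Helowrd'}
    byte_decoder['\u0120'] = ord(' ')

    words = []
    for token in [0, 1]:
        word = decoder[token]
        word = bytearray([byte_decoder[c] for c in word]).decode('utf-8', errors='replace')
        words.append(word)

    print(words)  # ['Hello', ' world']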
@@ -236,6 +258,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
+        output_logits = None
+
        if maxlen is None:
            maxlen = args.seq_length - 1
        
@@ -261,6 +285,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                if type_ids is not None:
                    types2use = type_ids[:, context_length - 1].view(
                        batch_size, -1)
            
            output, layer_past = forward_step(model, tokens2use,
                                              positions2use,
                                              attention_mask,
@@ -288,6 +313,19 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                
+                if output_logits is None:
+                    output_context = F.log_softmax(output[:, :context_length, :], 2)
+                    indices = torch.unsqueeze(tokens[:, :context_length],2)
+                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
+                else:
+                    indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
+                    new_output_logits = torch.gather(F.log_softmax(output,2), 2, indices).squeeze(2)
+
+                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
+                    output_logits = torch.cat([output_logits, new_output_logits],1)
+
+                #output_logits = torch.cat([output_logits, output[:,context_length,new_tokens]], 1)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)
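
Both branches above compute the same quantity, a [batch, positions] tensor of per-token log probabilities: the first pass covers every position in the prompt at once, and each subsequent step gathers only the newly sampled token and concatenates it on. The core log_softmax-plus-gather pattern, standalone (toy shapes):

    import torch
    import torch.nn.functional as F

    batch, seq, vocab = 2, 5, 10
    logits = torch.randn(batch, seq, vocab)          # model outputs
    tokens = torch.randint(0, vocab, (batch, seq))   # sampled token ids

    log_probs = F.log_softmax(logits, 2)                              # [batch, seq, vocab]
    indices = torch.unsqueeze(tokens, 2)                              # [batch, seq, 1]
    token_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)  # [batch, seq]
    print(token_log_probs.shape)  # torch.Size([2, 5])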
@@ -301,7 +339,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
-                yield tokens, lengths
+                yield tokens, lengths, output_logits

            else:
                if mpu.is_pipeline_first_stage():
@@ -310,9 +348,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
-                    yield tokens, None
+                    yield tokens, None, None
                else:
-                    yield None, None
+                    yield None, None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()