Commit ddd36145 authored by rprenger
Browse files

Pipe the full per-token log-probabilities (`all_probs`) through the generate API

parent 41df5ff7
Loading
Loading
Loading
Loading
+13 −1
Original line number Diff line number Diff line
@@ -47,8 +47,20 @@ class MegatronGenerate(Resource):
            if max_len < 1:
                return "max_len must be an integer greater than 0"

        all_probs = False
        if "all_probs" in request.get_json():
            all_probs = request.get_json()["all_probs"]
            if not isinstance(all_probs, bool):
                return "all_probs must be a boolean value"

        MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
        resp_sentences, resp_sentences_seg, output_logits = generate(self.model, sentences, max_len) 
        resp_sentences, resp_sentences_seg, output_logits, full_logits = generate(self.model, sentences, max_len, all_probs) 
        if all_probs:
            return jsonify({"sentences": resp_sentences,
                "segments": resp_sentences_seg,
                "logits": output_logits,
                "all_logits": full_logits})
        
        return jsonify({"sentences": resp_sentences,
            "segments": resp_sentences_seg,
            "logits": output_logits})
+43 −19
Original line number Diff line number Diff line
@@ -104,12 +104,12 @@ def tokenize_batch(sentences):
    context_length_tensor = torch.cuda.LongTensor(context_lengths)
    return context_tokens_tensor, context_length_tensor 

def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
def send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs):
    """
    Needs to be synced up with receive_generate_info
    """
    # Send the sizes of the tensors
    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len]
    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len, all_probs]
    input_info_tensor = torch.cuda.LongTensor(input_info)
    torch.distributed.broadcast(input_info_tensor, 0)

@@ -126,6 +126,7 @@ def receive_generate_info():
    batch_size = input_info_tensor[0].item()
    seq_len = input_info_tensor[1].item()
    max_len = input_info_tensor[2].item()
    all_probs = input_info_tensor[3].item()
    
    context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.device("cuda"))
    context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.device("cuda"))
@@ -134,23 +135,29 @@ def receive_generate_info():
    torch.distributed.broadcast(context_length_tensor, 0)
    torch.distributed.broadcast(context_tokens_tensor, 0)
    
    return context_length_tensor, context_tokens_tensor, max_len
    return context_length_tensor, context_tokens_tensor, max_len, all_probs

def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len):
def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs):
    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids,
                                                 max_len)
    for tokens, lengths, output_logits in batch_token_iterator:
                                                 max_len,
                                                 all_probs)
    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
        context_length += 1
                
    if mpu.is_pipeline_last_stage():
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        torch.distributed.broadcast(output_logits, src, group)
        if all_probs:
            src = mpu.get_pipeline_model_parallel_last_rank()
            group = mpu.get_embedding_group()
            torch.distributed.broadcast(full_logits, src, group)

    else:
        if mpu.is_pipeline_first_stage():
            src = mpu.get_pipeline_model_parallel_last_rank()
@@ -158,22 +165,28 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
            output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
            torch.distributed.broadcast(output_logits, src, group)
            
            if all_probs:
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size(), dtype=torch.float32, device=torch.device("cuda"))
                torch.distributed.broadcast(full_logits, src, group)
     
    if tokens is not None:
        return tokens[:, :context_length], output_logits 
        return tokens[:, :context_length], output_logits, full_logits 

def generate(model, sentences=None, max_len=0):
def generate(model, sentences=None, max_len=0, all_probs=False):
    if torch.distributed.get_rank() == 0:
        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
        c = context_length_tensor[0]
        b = context_tokens_tensor.size(0)
        start = time.time()
        send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
        send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
    else:
        context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
        context_length_tensor, context_tokens_tensor, max_len, all_probs = receive_generate_info()
    
    output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
    output = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len, all_probs)
    if output is not None:
        decode_tokens, output_logits = output
        decode_tokens, output_logits, full_logits = output

    if torch.distributed.get_rank() == 0:
        args = get_args()
@@ -191,9 +204,12 @@ def generate(model, sentences=None, max_len=0):
            resp_sentences_seg.append(words)

        output_logits = output_logits.cpu().numpy().tolist()
        if all_probs:
            full_logits = full_logits.cpu().numpy().tolist()

        end = time.time()
        print(str(b)+","+str(c)+","+str(decode_tokens.size(1))+","+str(end-start), flush=True)
        return resp_sentences, resp_sentences_seg, output_logits 
        return resp_sentences, resp_sentences_seg, output_logits, full_logits 

def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
@@ -236,7 +252,7 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,

def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
                          maxlen=None, all_probs=False, type_ids=None):
    args = get_args()
    tokenizer = get_tokenizer()

@@ -318,12 +334,17 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                    output_context = F.log_softmax(output[:, :context_length, :], 2)
                    indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                    if all_probs:
                        full_logits = output_context
                else:
                    output_context = F.log_softmax(output, 2)
                    indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
                    new_output_logits = torch.gather(F.log_softmax(output,2), 2, indices).squeeze(2)
                    new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                    
                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
                    output_logits = torch.cat([output_logits, new_output_logits],1)
                    if all_probs:
                        full_logits = torch.cat([full_logits, output_context], 1)
                
                #output_logits = torch.cat([output_logits, output[:,context_length,new_tokens]], 1)
                src = mpu.get_pipeline_model_parallel_last_rank()
@@ -339,7 +360,10 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                yield tokens, lengths, output_logits
                if all_probs:
                    yield tokens, lengths, output_logits, full_logits
                else:
                    yield tokens, lengths, output_logits, None

            else:
                if mpu.is_pipeline_first_stage():
@@ -348,9 +372,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None, None
                    yield tokens, None, None, None
                else:
                    yield None, None, None
                    yield None, None, None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()