Commit d33460df authored by Jared Casper

Merge branch 'just_score' into 'main'

This adds a function for the case where the user only wants the log-probabilities of the prompt tokens (tokens_to_generate=0).

See merge request ADLR/megatron-lm!345
parents 4145f091 d7699436
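
For illustration, a minimal sketch of the new scoring path as a caller would use it. The function name, the tokens_to_generate and return_output_log_probs arguments, and the return tuple come from the diff below; the module path is an assumption.

# Hypothetical usage sketch; the import path is an assumption.
from megatron.text_generation.api import generate_and_post_process

def score_prompts(model, prompts):
    # tokens_to_generate=0 now means "score the prompts, generate nothing";
    # log-probs must be requested explicitly.
    out = generate_and_post_process(
        model,                          # an initialized Megatron GPT model
        prompts=prompts,
        tokens_to_generate=0,
        return_output_log_probs=True)
    prompts_plus_generations, segments, output_log_probs, tokens = out
    return output_log_probs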
+12 −8
@@ -20,13 +20,13 @@ import torch

from megatron import mpu
from .communication import broadcast_float_list
-from .generation import generate_tokens_probs_and_return_on_first_stage
+from .generation import (
+        generate_tokens_probs_and_return_on_first_stage,
+        score_and_return_on_first_stage)
from .tokenization import (
    tokenize_prompts,
    detokenize_generations)



def generate_and_post_process(model,
                              prompts=None,
                              tokens_to_generate=0,
@@ -53,20 +53,19 @@ def generate_and_post_process(model,

    # Only post-process on first stage.
    if mpu.is_pipeline_first_stage():

        tokens, prompts_plus_generations, prompts_plus_generations_segments = \
            detokenize_generations(tokens, lengths, True)

        if return_output_log_probs:
            output_log_probs = output_log_probs.cpu().numpy().tolist()
            for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)):
                output_log_probs[i] = prob[:len(seg)-1]

        return prompts_plus_generations, prompts_plus_generations_segments, \
            output_log_probs, tokens

    return None



def generate(model,
             prompts=None,
             tokens_to_generate=0,
@@ -85,7 +84,8 @@ def generate(model,
    """

    # Make sure input params are available to all ranks.
-    values = [tokens_to_generate, return_output_log_probs,
+    values = [tokens_to_generate,
+              return_output_log_probs,
              top_k_sampling, top_p_sampling,
              temperature, add_BOS, use_eod_token_for_early_termination]
    values_float_tensor = broadcast_float_list(7, float_list=values)
@@ -101,10 +101,14 @@ def generate(model,
    # Note that these tensors are broadcast to all ranks.
    if torch.distributed.get_rank() == 0:
        assert prompts is not None
-        assert tokens_to_generate > 0
    
    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

+    if tokens_to_generate == 0:
+        return score_and_return_on_first_stage(
+            model, context_tokens_tensor, context_length_tensor)

    # Main inference function.
    # Note that the outputs are available on the first stage.
    return generate_tokens_probs_and_return_on_first_stage(
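
The values list above follows a common pattern: rank 0 packs the scalar sampling parameters into a single float tensor and broadcasts it so every rank takes the same code path. A self-contained sketch of that pattern using plain torch.distributed; broadcast_float_list itself is a Megatron-internal helper whose implementation is not part of this diff, so the helper below is a hypothetical stand-in.

import torch
import torch.distributed as dist

def broadcast_scalars(size, values=None):
    # Hypothetical stand-in for broadcast_float_list: rank 0 supplies
    # `values`, every other rank receives them.
    tensor = torch.empty(size, dtype=torch.float32,
                         device=torch.cuda.current_device())
    if dist.get_rank() == 0:
        tensor = torch.tensor(values, dtype=torch.float32,
                              device=torch.cuda.current_device())
    dist.broadcast(tensor, src=0)
    return tensor

This is why the docstring below can say the parameters only need to be available on rank 0: non-zero ranks call with values=None and receive the broadcast tensor.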
+66 −2
@@ -27,6 +27,69 @@ from .communication import (
from .forward_step import ForwardStep
from .sampling import sample

def score_and_return_on_first_stage(model, tokens, lengths):
    """Function for just scoring.
    Arguments:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max_prompt_length]
        lengths: original prompt length, size: [b]
    Note: Outside of model, other parameters only need to be available on
          rank 0.
    Outputs:
        output_log_probs: log probability of the prompt tokens. size: [b, s-1]
    """

    args = get_args()

    batch_size = tokens.size(0)
    max_prompt_length = lengths.max().item()
    assert max_prompt_length == tokens.size(1)
    max_sequence_length = min(max_prompt_length, args.max_position_embeddings)

    # forward step.
    forward_step = ForwardStep(model, batch_size, max_sequence_length)

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    
    if mpu.is_pipeline_last_stage():
        output_log_probs = torch.empty(output_log_probs_size,
                                       dtype=torch.float32,
                                       device=torch.cuda.current_device())
    
    # =============
    # Run inference
    # =============
    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
        
        # logits will be meaningful only in the last pipeline stage.
        logits = forward_step(tokens, position_ids, attention_mask)

        if mpu.is_pipeline_last_stage():
            # The last stage should always have an output.
            assert logits is not None
            log_probs = F.log_softmax(logits, dim=2)
            
            # Pick the tokens that we need to get the log
            # probabilities for. Note that next input token is
            # the token which we selected in the current logits,
            # so shift by 1.
            indices = torch.unsqueeze(tokens[:, 1:], 2)
            output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
    
    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================
    output_log_probs = broadcast_from_last_to_first_pipeline_stage(
        output_log_probs_size, torch.float32, output_log_probs)
    
    return tokens, lengths, output_log_probs

def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
@@ -95,6 +158,7 @@ def generate_tokens_probs_and_return_on_first_stage(
        generated_sequence_lengths = torch.ones(
                batch_size, dtype=torch.int64,
                device=torch.cuda.current_device()) * max_sequence_length
    
    # Whether we have reached a termination id.
    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())
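
The heart of score_and_return_on_first_stage above is the shift-by-one gather: the logits at position i parameterize the distribution over the token at position i+1, so indexing log_softmax(logits) with the next tokens reads off the log-probability of the whole prompt. A minimal single-GPU sketch of the same computation on toy shapes, with the model stubbed out by random logits:

import torch
import torch.nn.functional as F

b, s, v = 2, 5, 100                # batch, prompt length, vocab size
tokens = torch.randint(v, (b, s))
logits = torch.randn(b, s, v)      # stand-in for forward_step(...)

log_probs = F.log_softmax(logits, dim=2)
# logits[:, i] predicts tokens[:, i+1], hence the shift by one;
# positions 0..s-2 each score the token that follows them.
indices = torch.unsqueeze(tokens[:, 1:], 2)   # [b, s-1, 1]
output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
assert output_log_probs.shape == (b, s - 1)

Note that torch.gather allows the index to be smaller than the input along non-gather dimensions, which is why the [b, s-1, 1] index against the [b, s, v] log-probs yields the [b, s-1] result directly, exactly as in the diffed code.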
+7 −4
@@ -58,8 +58,8 @@ class MegatronGenerate(Resource):
            tokens_to_generate = request.get_json()["tokens_to_generate"]
            if not isinstance(tokens_to_generate, int):
                return "tokens_to_generate must be an integer greater than 0"
-            if tokens_to_generate < 1:
-                return "tokens_to_generate must be an integer greater than 0"
+            if tokens_to_generate < 0:
+                return "tokens_to_generate must be an integer greater than or equal to 0"

        logprobs = False
        if "logprobs" in request.get_json():
@@ -67,6 +67,9 @@ class MegatronGenerate(Resource):
            if not isinstance(logprobs, bool):
                return "logprobs must be a boolean value"
        
        if tokens_to_generate == 0 and not logprobs:
            return "tokens_to_generate=0 implies logprobs should be True"
        
        temperature = 1.0
        if "temperature" in request.get_json():
            temperature = request.get_json()["temperature"]
@@ -80,7 +83,7 @@ class MegatronGenerate(Resource):
            top_k = request.get_json()["top_k"]
            if not (type(top_k) == int):
                return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
-            if not (0 < top_k <= 1000):
+            if not (0 <= top_k <= 1000):
                return "top_k must be equal to or greater than 0 and less than or equal to 1000"
        
        top_p = 0.0
@@ -90,7 +93,7 @@ class MegatronGenerate(Resource):
                return "top_p must be a positive float less than or equal to 1.0"
            if top_p > 0.0 and top_k > 0.0:
                return "cannot set both top-k and top-p samplings."
-            if not (0 < top_p <= 1.0):
+            if not (0 <= top_p <= 1.0):
                return "top_p must be less than or equal to 1.0"
        
        add_BOS = False
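
Putting the server-side validation together, a prompt can now be scored over HTTP by sending tokens_to_generate=0 with logprobs=True. A hypothetical client call follows; the URL, port, and HTTP verb are assumptions, since the route registration is not part of this diff, and only the tokens_to_generate and logprobs fields are validated in the code shown above.

import requests

response = requests.put(
    "http://localhost:5000/api",       # assumed mount point and port
    json={
        "prompts": ["The quick brown fox jumps over the lazy dog"],
        "tokens_to_generate": 0,       # score only, no generation
        "logprobs": True,              # required when tokens_to_generate == 0
    })
print(response.json())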