Commit 9cc286ba authored by rprenger

Getting tokens_to_generate=0 to work

parent 21d2b0fe
+21 −23
@@ -20,7 +20,9 @@ import torch

from megatron import mpu
from .communication import broadcast_float_list
-from .generation import generate_tokens_probs_and_return_on_first_stage
+from .generation import (
+        generate_tokens_probs_and_return_on_first_stage,
+        score_and_return_on_first_stage)
from .tokenization import (
    tokenize_prompts,
    detokenize_generations)
@@ -31,7 +33,6 @@ def generate_and_post_process(model,
                              prompts=None,
                              tokens_to_generate=0,
                              return_output_log_probs=False,
-                              return_all_log_probs=False,
                              greedy_sampling=False,
                              top_k_sampling=0,
                              top_p_sampling=0.0,
@@ -43,12 +44,11 @@ def generate_and_post_process(model,
    move to cpu and convert to list."""

    # Main inference.
-    tokens, lengths, output_log_probs, all_log_probs = generate(
+    tokens, lengths, output_log_probs = generate(
        model,
        prompts=prompts,
        tokens_to_generate=tokens_to_generate,
        return_output_log_probs=return_output_log_probs,
-        return_all_log_probs=return_all_log_probs,
        greedy_sampling=greedy_sampling,
        top_k_sampling=top_k_sampling,
        top_p_sampling=top_p_sampling,
@@ -59,17 +59,16 @@ def generate_and_post_process(model,

    # Only post-process on first stage.
    if mpu.is_pipeline_first_stage():

        tokens, prompts_plus_generations, prompts_plus_generations_segments = \
            detokenize_generations(tokens, lengths, True)

        if return_output_log_probs:
            output_log_probs = output_log_probs.cpu().numpy().tolist()
-        if return_all_log_probs:
-            all_log_probs = all_log_probs.cpu().numpy().tolist()
+            for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)):
+                output_log_probs[i] = prob[:len(seg)-1]

        return prompts_plus_generations, prompts_plus_generations_segments, \
-            output_log_probs, all_log_probs, tokens
+            output_log_probs, tokens

    return None
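The trimming loop added above exists because a detokenized sequence of n tokens has only n-1 next-token log probabilities, so any padding past the real sequence has to be cut per prompt. A minimal sketch of the same logic with made-up values (names mirror the loop above):

```python
# Minimal sketch of the trimming loop, with made-up values.
# A sequence of n tokens has n-1 next-token log probs, so each
# row is cut to len(seg) - 1.
output_log_probs = [
    [-0.5, -1.2, -0.3, 0.0, 0.0],  # padded past the real sequence
    [-0.9, -0.1, 0.0, 0.0, 0.0],
]
prompts_plus_generations_segments = [
    ["The", "cat", "sat", "."],    # 4 tokens -> 3 log probs
    ["Hi", "!"],                   # 2 tokens -> 1 log prob
]
for i, (prob, seg) in enumerate(
        zip(output_log_probs, prompts_plus_generations_segments)):
    output_log_probs[i] = prob[:len(seg) - 1]

assert output_log_probs == [[-0.5, -1.2, -0.3], [-0.9]]
```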

@@ -79,7 +78,6 @@ def generate(model,
             prompts=None,
             tokens_to_generate=0,
             return_output_log_probs=False,
-             return_all_log_probs=False,
             greedy_sampling=False,
             top_k_sampling=0,
             top_p_sampling=0.0,
@@ -93,25 +91,23 @@ def generate(model,
           discard tokens in the tokens tensor that are after the
           corresponding length.
       output_log_probs: log probs of the tokens.
-       all_log_probs: full log probs for all of tokens.
    """

    # Make sure input params are available to all ranks.
    values = [tokens_to_generate,
-              return_output_log_probs, return_all_log_probs,
+              return_output_log_probs,
              greedy_sampling, top_k_sampling, top_p_sampling,
              temperature, add_BOS, use_eod_token_for_early_termination, just_score]
-    values_float_tensor = broadcast_float_list(10, float_list=values)
+    values_float_tensor = broadcast_float_list(9, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
-    return_all_log_probs = bool(values_float_tensor[2].item())
-    greedy_sampling = bool(values_float_tensor[3].item())
-    top_k_sampling = int(values_float_tensor[4].item())
-    top_p_sampling = values_float_tensor[5].item()
-    temperature = values_float_tensor[6].item()
-    add_BOS = bool(values_float_tensor[7].item())
-    use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
-    just_score = bool(values_float_tensor[9].item())
+    greedy_sampling = bool(values_float_tensor[2].item())
+    top_k_sampling = int(values_float_tensor[3].item())
+    top_p_sampling = values_float_tensor[4].item()
+    temperature = values_float_tensor[5].item()
+    add_BOS = bool(values_float_tensor[6].item())
+    use_eod_token_for_early_termination = bool(values_float_tensor[7].item())
+    just_score = bool(values_float_tensor[8].item())

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcast to all ranks.
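The pack-and-broadcast pattern above keeps every rank in sync with a single collective: all scalar options are flattened into one float tensor on rank 0, broadcast, and recovered by casting each slot back to its intended type. A self-contained sketch of the round trip (the tensor is built locally here; in the real code `broadcast_float_list` does the distributed part):

```python
import torch

# Mirrors the new 9-element list above; just_score rides along
# as a float like everything else.
values = [0,      # tokens_to_generate
          True,   # return_output_log_probs
          False,  # greedy_sampling
          0,      # top_k_sampling
          0.9,    # top_p_sampling
          1.0,    # temperature
          False,  # add_BOS
          True,   # use_eod_token_for_early_termination
          True]   # just_score
values_float_tensor = torch.tensor(values, dtype=torch.float32)
# On a real run, broadcast_float_list(9, float_list=values) would move
# this tensor from rank 0 to all ranks before unpacking.
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
just_score = bool(values_float_tensor[8].item())
assert just_score and tokens_to_generate == 0
```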
@@ -121,13 +117,15 @@ def generate(model,
    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

+    if just_score:
+        return score_and_return_on_first_stage(
+            model, context_tokens_tensor, context_length_tensor)
+
    # Main inference function.
    # Note that the outputs are available on the first stage.
    return generate_tokens_probs_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
-        return_all_log_probs=return_all_log_probs,
        greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
        temperature=temperature,
-        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
-        just_score=just_score)
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
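With this change, `tokens_to_generate=0` short-circuits generation entirely: the tokenized prompts go straight to `score_and_return_on_first_stage`, and the caller gets back the prompt tokens, their lengths, and per-token log probs. A hypothetical call sketch (the prompt and `model` are placeholders, and `just_score` is assumed to be a keyword of `generate`, consistent with the broadcast list above):

```python
# Hypothetical usage; model and the prompt are placeholders.
tokens, lengths, output_log_probs = generate(
    model,
    prompts=["The capital of France is Paris."],
    tokens_to_generate=0,          # score only, generate nothing
    return_output_log_probs=True,
    just_score=True)
# output_log_probs[b, t] is the model's log prob of prompt token t+1.
```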
+68 −30
@@ -27,15 +27,76 @@ from .communication import (
from .forward_step import ForwardStep
from .sampling import sample

+def score_and_return_on_first_stage(model, tokens, lengths):
+    """Function for just scoring.
+    Arguments:
+        model: no interleaving is supported.
+        tokens: prompt tokens extended to be of size [b, max_prompt_length]
+        lengths: original prompt length, size: [b]
+    Note: Outside of model, other parameters only need to be available on
+          rank 0.
+    Outputs:
+        output_log_probs: log probability of the selected tokens. size: [b, s]
+    """
+
+    args = get_args()
+
+    batch_size = tokens.size(0)
+    max_prompt_length = lengths.max().item()
+    assert max_prompt_length == tokens.size(1)
+    max_sequence_length = min(max_prompt_length, args.max_position_embeddings)
+
+    # forward step.
+    forward_step = ForwardStep(model, batch_size, max_sequence_length)
+
+    # ===================
+    # Pre-allocate memory
+    # ===================
+
+    # Log probability of the sequence (prompt + generated tokens).
+    output_log_probs = None
+    output_log_probs_size = (batch_size, max_sequence_length - 1)
+
+    if mpu.is_pipeline_last_stage():
+        output_log_probs = torch.empty(output_log_probs_size,
+                                       dtype=torch.float32,
+                                       device=torch.cuda.current_device())
+
+    # =============
+    # Run inference
+    # =============
+    with torch.no_grad():
+        attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
+
+        # logits will be meaningful only in the last pipeline stage.
+        logits = forward_step(tokens, position_ids, attention_mask)
+
+        if mpu.is_pipeline_last_stage():
+            # The last stage should always have an output.
+            assert logits is not None
+            log_probs = F.log_softmax(logits, dim=2)
+
+            # Pick the tokens that we need to get the log
+            # probabilities for. Note that the next input token is
+            # the one selected from the current logits, so shift by 1.
+            indices = torch.unsqueeze(tokens[:, 1:], 2)
+            output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
+
+    # ======================================
+    # Broadcast to the first pipeline stage.
+    # ======================================
+    output_log_probs = broadcast_from_last_to_first_pipeline_stage(
+        output_log_probs_size, torch.float32, output_log_probs)
+
+    return tokens, lengths, output_log_probs
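
The core of this new scoring path is the log-softmax plus shifted gather: position t's distribution predicts token t+1, so the realized tokens `tokens[:, 1:]` index into the first s-1 distributions. A standalone sketch with random stand-in values:

```python
import torch
import torch.nn.functional as F

# Stand-ins for a real forward pass and real prompt tokens.
batch_size, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch_size, seq_len, vocab)
tokens = torch.randint(vocab, (batch_size, seq_len))

log_probs = F.log_softmax(logits, dim=2)
# Position t predicts token t+1, so gather tokens[:, 1:] against the
# first seq_len - 1 distributions (gather permits the smaller index).
indices = torch.unsqueeze(tokens[:, 1:], 2)
output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
assert output_log_probs.shape == (batch_size, seq_len - 1)
```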

def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
-        return_all_log_probs=False,
        greedy=False, top_k=0, top_p=0.0,
        temperature=1.0,
-        use_eod_token_for_early_termination=True,
-        just_score=False):
+        use_eod_token_for_early_termination=True):
    """Main token generation function.
    Arguments:
        model: no interleaving is supported.
@@ -44,9 +105,6 @@ def generate_tokens_probs_and_return_on_first_stage(
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            after logits are modified for sampling.
-        return_all_log_probs: flag to calculate the log probability of across
-            all the tokens (vocab size). Note that the log probability is the
-            one after logits are modifed for sampling.
        greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
            Note that these three parameters are exclusive, meaning that:
                if greedy = true then we should have top-k=top-p=0.
@@ -63,8 +121,6 @@ def generate_tokens_probs_and_return_on_first_stage(
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]
-        all_log_probs: log probability of all the tokens.
-            size: [b, s, vocab-size]
    """

    args = get_args()
@@ -93,9 +149,7 @@ def generate_tokens_probs_and_return_on_first_stage(
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
-    # Log probability of all tokens for the sequence.
-    all_log_probs = None
-    all_log_probs_size = (batch_size, max_sequence_length -1,
-                          args.padded_vocab_size)
    
    # Lengths of generated sequence including prompts.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
@@ -103,10 +157,6 @@ def generate_tokens_probs_and_return_on_first_stage(
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
-        if return_all_log_probs:
-            all_log_probs = torch.empty(all_log_probs_size,
-                                        dtype=torch.float32,
-                                        device=torch.cuda.current_device())
        generated_sequence_lengths = torch.ones(
                batch_size, dtype=torch.int64,
                device=torch.cuda.current_device()) * max_sequence_length
@@ -159,12 +209,8 @@ def generate_tokens_probs_and_return_on_first_stage(
                tokens[started, context_length] = new_sample[started]

                # Calculate the log probabilities.
-                if return_output_log_probs or return_all_log_probs:
+                if return_output_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
-                    if return_all_log_probs:
-                        all_log_probs[:,
-                                      prev_context_length:context_length,
-                                      :] = log_probs
-                    if return_output_log_probs:
                        # Pick the tokens that we need to get the log
                        # probabilities for. Note that next input token is
@@ -210,8 +256,6 @@ def generate_tokens_probs_and_return_on_first_stage(
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length].contiguous()
-        if return_all_log_probs:
-            all_log_probs = all_log_probs[:, :context_length, :]

    # ======================================
    # Broadcast to the first pipeline stage.
@@ -223,14 +267,8 @@ def generate_tokens_probs_and_return_on_first_stage(
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)
-    if return_all_log_probs:
-        all_log_probs_size = (batch_size, context_length,
-                              args.padded_vocab_size)
-        all_log_probs = broadcast_from_last_to_first_pipeline_stage(
-            all_log_probs_size, torch.float32, all_log_probs)

-    return tokens, generated_sequence_lengths, output_log_probs, \
-        all_log_probs
+    return tokens, generated_sequence_lengths, output_log_probs
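
Both the scoring and generation paths end with `broadcast_from_last_to_first_pipeline_stage`, since logits (and hence log probs) only exist on the last pipeline stage while post-processing happens on the first. A conceptual sketch, not the repo's helper: the group handle and source rank are assumptions, and the key point is that every participant must allocate a buffer of the agreed size before the collective, which is why `output_log_probs_size` is computed on all ranks.

```python
import torch
import torch.distributed as dist

def broadcast_last_to_first(size, dtype, tensor, src_rank, group):
    # Non-source ranks hold None and must allocate a buffer of the
    # same size and dtype for the broadcast to land in.
    if tensor is None:
        tensor = torch.empty(size, dtype=dtype,
                             device=torch.cuda.current_device())
    dist.broadcast(tensor, src_rank, group=group)
    return tensor
```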



+5 −4
@@ -69,6 +69,8 @@ class MegatronGenerate(Resource):
            logprobs = request.get_json()["logprobs"]
            if not isinstance(logprobs, bool):
                return "logprobs must be a boolean value"
+            if just_score and not logprobs:
+                return "tokens_to_generate=0 implies logprobs=True"
        
        temperature = 1.0
        if "temperature" in request.get_json():
@@ -83,7 +85,7 @@ class MegatronGenerate(Resource):
            top_k = request.get_json()["top_k"]
            if not (type(top_k) == int):
                return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
-            if not (0 < top_k <= 1000):
+            if not (0 <= top_k <= 1000):
                return "top_k must be equal to or greater than 0 and less than or equal to 1000"
        
        top_p = 0.0
@@ -93,7 +95,7 @@ class MegatronGenerate(Resource):
                return "top_p must be a positive float less than or equal to 1.0"
            if top_p > 0.0 and top_k > 0.0:
                return "cannot set both top-k and top-p samplings."
-            if not (0 < top_p <= 1.0):
+            if not (0 <= top_p <= 1.0):
                return "top_p must be less than or equal to 1.0"
        
        add_BOS = False
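
Relaxing both lower bounds from `0 <` to `0 <=` is what lets a scoring-only request pass validation: it samples nothing and leaves both knobs at their disabled defaults. A minimal sketch, not the repo's sampling code, of the convention these checks encode: 0 disables top-k, 0.0 disables top-p, and the two are mutually exclusive.

```python
import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=0, top_p=0.0):
    assert not (top_k > 0 and top_p > 0.0), "top-k and top-p are exclusive"
    if top_k > 0:
        # Keep only the k largest logits.
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    elif top_p > 0.0:
        # Nucleus filtering: drop a token once the cumulative probability
        # of the tokens ranked above it exceeds top_p.
        sorted_logits, idx = torch.sort(logits, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        drop_sorted = torch.cumsum(probs, dim=-1) - probs > top_p
        drop = drop_sorted.scatter(-1, idx, drop_sorted)
        logits = logits.masked_fill(drop, float("-inf"))
    # top_k == 0 and top_p == 0.0 fall through: nothing is filtered.
    return logits
```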
@@ -104,13 +106,12 @@ class MegatronGenerate(Resource):

        with lock:  # Need to get lock to keep multiple threads from hitting code
            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            response, response_seg, response_logprobs, _, _ = \
+            response, response_seg, response_logprobs, _ = \
                generate_and_post_process(
                    self.model,
                    prompts=prompts,
                    tokens_to_generate=tokens_to_generate,
                    return_output_log_probs=logprobs,
-                    return_all_log_probs=False,
                    greedy_sampling=args.greedy,
                    top_k_sampling=top_k,
                    top_p_sampling=top_p,