Commit c6e7c7fd authored by mshoeybi

removed the return_all_log_probs option and all_log_probs outputs from text generation

parent 8d405805
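
At a glance, the diff below removes the return_all_log_probs path end to end: the API entry point, the cross-rank broadcast of sampling options, the generation loop, and the REST server all stop producing the full-vocabulary all_log_probs tensor. A rough caller-side sketch of the before/after, using only names that appear in the diff:

    # before this commit: five return values, including full-vocab log probs
    #   out, segs, output_log_probs, all_log_probs, tokens = \
    #       generate_and_post_process(model, prompts=prompts,
    #                                 return_all_log_probs=True, ...)
    # after this commit: four return values, per-token log probs only
    #   out, segs, output_log_probs, tokens = \
    #       generate_and_post_process(model, prompts=prompts, ...)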
+10 −19
@@ -31,7 +31,6 @@ def generate_and_post_process(model,
                              prompts=None,
                              tokens_to_generate=0,
                              return_output_log_probs=False,
-                             return_all_log_probs=False,
                              greedy_sampling=False,
                              top_k_sampling=0,
                              top_p_sampling=0.0,
@@ -42,12 +41,11 @@ def generate_and_post_process(model,
    move to cpu and convert to list."""

    # Main inference.
-    tokens, lengths, output_log_probs, all_log_probs = generate(
+    tokens, lengths, output_log_probs = generate(
        model,
        prompts=prompts,
        tokens_to_generate=tokens_to_generate,
        return_output_log_probs=return_output_log_probs,
-        return_all_log_probs=return_all_log_probs,
        greedy_sampling=greedy_sampling,
        top_k_sampling=top_k_sampling,
        top_p_sampling=top_p_sampling,
@@ -63,11 +61,9 @@ def generate_and_post_process(model,

        if return_output_log_probs:
            output_log_probs = output_log_probs.cpu().numpy().tolist()
-        if return_all_log_probs:
-            all_log_probs = all_log_probs.cpu().numpy().tolist()

        return prompts_plus_generations, prompts_plus_generations_segments, \
-            output_log_probs, all_log_probs, tokens
+            output_log_probs, tokens

    return None

@@ -77,7 +73,6 @@ def generate(model,
             prompts=None,
             tokens_to_generate=0,
             return_output_log_probs=False,
-             return_all_log_probs=False,
             greedy_sampling=False,
             top_k_sampling=0,
             top_p_sampling=0.0,
@@ -90,24 +85,21 @@ def generate(model,
           discard tokens in the tokens tensor that are after the
           corresponding length.
       output_log_probs: log probs of the tokens.
-       all_log_probs: full log probs for all of tokens.
    """

    # Make sure input params are avaialble to all ranks.
-    values = [tokens_to_generate,
-              return_output_log_probs, return_all_log_probs,
+    values = [tokens_to_generate, return_output_log_probs,
              greedy_sampling, top_k_sampling, top_p_sampling,
              temperature, add_BOS, use_eod_token_for_early_termination]
-    values_float_tensor = broadcast_float_list(9, float_list=values)
+    values_float_tensor = broadcast_float_list(8, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
-    return_all_log_probs = bool(values_float_tensor[2].item())
-    greedy_sampling = bool(values_float_tensor[3].item())
-    top_k_sampling = int(values_float_tensor[4].item())
-    top_p_sampling = values_float_tensor[5].item()
-    temperature = values_float_tensor[6].item()
-    add_BOS = bool(values_float_tensor[7].item())
-    use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
+    greedy_sampling = bool(values_float_tensor[2].item())
+    top_k_sampling = int(values_float_tensor[3].item())
+    top_p_sampling = values_float_tensor[4].item()
+    temperature = values_float_tensor[5].item()
+    add_BOS = bool(values_float_tensor[6].item())
+    use_eod_token_for_early_termination = bool(values_float_tensor[7].item())

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcaseted to all ranks.
@@ -122,7 +114,6 @@ def generate(model,
    return generate_tokens_probs_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
-        return_all_log_probs=return_all_log_probs,
        greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
        temperature=temperature,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
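
The broadcast_float_list change above (9 to 8) reflects that the helper's first argument is the number of scalars being packed; with return_all_log_probs gone the option list shrinks by one and every rank unpacks the tensor positionally, which is why all the later indices shift. A minimal sketch of what such a helper typically does (the real implementation lives elsewhere in this module; names and details here are assumptions):

    import torch

    def broadcast_float_list_sketch(size, float_list=None, rank=0):
        """Pack scalar options into a float32 CUDA tensor on `rank` and
        broadcast it, so every rank decodes the same generation settings."""
        if torch.distributed.get_rank() == rank:
            tensor = torch.tensor(float_list, dtype=torch.float32,
                                  device=torch.cuda.current_device())
        else:
            tensor = torch.empty(size, dtype=torch.float32,
                                 device=torch.cuda.current_device())
        torch.distributed.broadcast(tensor, rank)
        return tensor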
+35 −13
@@ -55,13 +55,31 @@ def send_to_next_pipeline_rank(tensor=None):



-def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
-    """Broadcast a tensor from last pipeline stage to all ranks."""
-
-    if mpu.is_pipeline_last_stage():
+def _is_cuda(tensor):
+    """Check if a tensor is not none and is cuda."""
+    assert tensor is not None
+    assert tensor.is_cuda
+
+
+
+def _is_cuda_contiguous(tensor):
+    """Check if a tensor is not none, is cuda, and is contiguous."""
+    _is_cuda(tensor)
+    assert tensor.is_contiguous()
+
+
+
+def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
+    """Broadcast a tensor from last pipeline stage to all ranks."""
+
+    is_last_stage = mpu.is_pipeline_last_stage()
+    # If first stage and last state are the same, then there is no
+    # pipeline parallelism and no need to communicate.
+    if mpu.is_pipeline_first_stage() and is_last_stage:
+        return tensor
+
+    if is_last_stage:
+        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
@@ -78,14 +96,16 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from last stage into the first stage."""

-    # Only first and last stage pipeline stages need to be involved.
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
+    # If first stage and last state are the same, then there is no
+    # pipeline parallelism and no need to communicate.
+    if is_first_stage and is_last_stage:
+        return tensor
+    # Only first and last stage pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        if is_last_stage:
-            assert tensor is not None
-            assert tensor.is_cuda
-            assert tensor.is_contiguous()
+            _is_cuda_contiguous(tensor)
        else:
            tensor = torch.empty(size,
                                 dtype=dtype,
@@ -105,12 +125,15 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from last stage into the first stage.
    Note that the input tensor is updated in place."""

-    # Only first and last stage pipeline stages need to be involved.
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
+    # If first stage and last state are the same, then there is no
+    # pipeline parallelism and no need to communicate.
+    if is_first_stage and is_last_stage:
+        return
+    # Only first and last stage pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
-        assert tensor is not None
-        assert tensor.is_cuda
+        _is_cuda(tensor)
        is_contiguous = tensor.is_contiguous()
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
@@ -137,8 +160,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
    """

    if torch.distributed.get_rank() == rank:
-        assert tensor is not None
-        assert tensor.is_cuda
+        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
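
The new _is_cuda and _is_cuda_contiguous helpers factor out the buffer checks that torch.distributed requires, and each communication routine above now returns early when the first and last pipeline stages coincide, i.e. when there is no pipeline parallelism. A condensed sketch of that pattern, built from the mpu calls that appear in this diff (treat the import path and the group choice as assumptions):

    import torch
    from megatron import mpu  # assumed import path for the parallel-state utilities

    def broadcast_last_to_first_sketch(size, dtype, tensor=None):
        is_last = mpu.is_pipeline_last_stage()
        is_first = mpu.is_pipeline_first_stage()
        # No pipeline parallelism: the single stage already owns the data.
        if is_first and is_last:
            return tensor
        # Only the first and last stages take part in this exchange.
        if is_last or is_first:
            if is_last:
                # Sender must hand torch.distributed a CUDA, contiguous buffer.
                assert tensor is not None and tensor.is_cuda and tensor.is_contiguous()
            else:
                # Receiver allocates an empty buffer of the advertised size/dtype.
                tensor = torch.empty(size, dtype=dtype,
                                     device=torch.cuda.current_device())
            src = mpu.get_pipeline_model_parallel_last_rank()
            torch.distributed.broadcast(tensor, src,
                                        group=mpu.get_embedding_group())
        return tensor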
+3 −29
@@ -31,7 +31,6 @@ from .sampling import sample
def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
-        return_all_log_probs=False,
        greedy=False, top_k=0, top_p=0.0,
        temperature=1.0,
        use_eod_token_for_early_termination=True):
@@ -43,9 +42,6 @@ def generate_tokens_probs_and_return_on_first_stage(
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            after logits are modifed for sampling.
-        return_all_log_probs: flag to calculate the log probability of across
-            all the tokens (vocab size). Note that the log probability is the
-            one after logits are modifed for sampling.
        greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
            Note that these three paramters are exclusive meaning that:
                if greedy = true then we should have top-k=top-p=0.
@@ -62,8 +58,6 @@ def generate_tokens_probs_and_return_on_first_stage(
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]
-        all_log_probs: log probability of all the tokens.
-            size: [b, s, vocab-size]
    """

    args = get_args()
@@ -91,10 +85,6 @@ def generate_tokens_probs_and_return_on_first_stage(
    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
-    # Log probability of all tokens for the sequence.
-    all_log_probs = None
-    all_log_probs_size = (batch_size, max_sequence_length -1,
-                          args.padded_vocab_size)
    # Lengths of generated seuquence including including prompts.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
@@ -102,10 +92,6 @@ def generate_tokens_probs_and_return_on_first_stage(
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
-        if return_all_log_probs:
-            all_log_probs = torch.empty(all_log_probs_size,
-                                        dtype=torch.float32,
-                                        device=torch.cuda.current_device())
        generated_sequence_lengths = torch.ones(
            batch_size, dtype=torch.int64,
            device=torch.cuda.current_device()) * max_sequence_length
@@ -157,12 +143,8 @@ def generate_tokens_probs_and_return_on_first_stage(
                tokens[started, context_length] = new_sample[started]

                # Calculate the log probabilities.
-                if return_output_log_probs or return_all_log_probs:
+                if return_output_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
-                    if return_all_log_probs:
-                        all_log_probs[:,
-                                      prev_context_length:context_length,
-                                      :] = log_probs
                    if return_output_log_probs:
                        # Pick the tokens that we need to get the log
                        # probabilities for. Note that next input token is
@@ -208,8 +190,6 @@ def generate_tokens_probs_and_return_on_first_stage(
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length]
-        if return_all_log_probs:
-            all_log_probs = all_log_probs[:, :context_length, :]

    # ======================================
    # Broadcast to the first pipeline stage.
@@ -221,14 +201,8 @@ def generate_tokens_probs_and_return_on_first_stage(
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)
-    if return_all_log_probs:
-        all_log_probs_size = (batch_size, context_length,
-                              args.padded_vocab_size)
-        all_log_probs = broadcast_from_last_to_first_pipeline_stage(
-            all_log_probs_size, torch.float32, all_log_probs)

-    return tokens, generated_sequence_lengths, output_log_probs, \
-        all_log_probs

+    return tokens, generated_sequence_lengths, output_log_probs
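
What survives in the generation loop is the per-token path: take a log-softmax over the logits of the freshly generated positions and gather only the log probs of the tokens that were actually selected, instead of materializing the [b, s, vocab-size] tensor that return_all_log_probs used to keep. A sketch of that step, following the comments in the hunk above (the exact slicing is an assumption based on the "shift by 1" note):

    import torch
    import torch.nn.functional as F

    def selected_token_log_probs(logits, tokens, prev_context_length, context_length):
        # logits: [b, s, vocab] for the positions generated in this step.
        log_probs = F.log_softmax(logits, dim=2)
        # The next input token is the one sampled from the current logits,
        # so the token slice is shifted right by one position.
        indices = torch.unsqueeze(
            tokens[:, (prev_context_length + 1):(context_length + 1)], 2)
        # Result is [b, s]: one log prob per selected token, not per vocab entry.
        return torch.gather(log_probs, 2, indices).squeeze(2)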



+1 −2
@@ -101,13 +101,12 @@ class MegatronGenerate(Resource):

        with lock:  # Need to get lock to keep multiple threads from hitting code
            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            response, response_seg, response_logprobs, _, _ = \
+            response, response_seg, response_logprobs, _ = \
                generate_and_post_process(
                    self.model,
                    prompts=prompts,
                    tokens_to_generate=tokens_to_generate,
                    return_output_log_probs=logprobs,
-                    return_all_log_probs=False,
                    greedy_sampling=args.greedy,
                    top_k_sampling=top_k,
                    top_p_sampling=top_p,
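
With the argument removed, a caller of generate_and_post_process passes only the remaining options and unpacks four values, as the server code above now does. A minimal usage sketch with the post-commit signature (argument and return names are taken from the diff; the import path and prompt are assumptions):

    from megatron.text_generation import generate_and_post_process  # assumed path

    # On ranks other than the first pipeline stage the call returns None.
    response, response_segments, response_log_probs, tokens = \
        generate_and_post_process(
            model,
            prompts=["Megatron-LM is"],
            tokens_to_generate=32,
            return_output_log_probs=True,  # per-token log probs of the sampled tokens
            top_p_sampling=0.9,
            # return_all_log_probs is gone; passing it now raises a TypeError.
        )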