Commit 21d2b0fe authored by rprenger

Allowing for a 0-tokens / just-scoring mode

parent 8d405805
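
This commit threads a new just_score flag from the REST server down through generate_and_post_process, generate, and the first-stage generation loop, so a request with tokens_to_generate == 0 scores the prompt (returns its per-token log-probabilities) instead of sampling new tokens. A hypothetical call-site sketch of the resulting API; only the argument names come from the diff below, the model/prompt setup and the shape of the return value are assumptions:

# Hypothetical call-site sketch: score prompts without generating.
# just_score and tokens_to_generate=0 are what this commit enables;
# `model` and the structure of `result` are assumed, not shown in the diff.
result = generate_and_post_process(
    model,
    prompts=["Paris is the capital of France."],
    tokens_to_generate=0,          # no new tokens are sampled
    return_output_log_probs=True,  # request per-token log-probs instead
    just_score=True)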
+12 −7
@@ -37,7 +37,8 @@ def generate_and_post_process(model,
                               top_p_sampling=0.0,
                               temperature=1.0,
                               add_BOS=False,
-                              use_eod_token_for_early_termination=True):
+                              use_eod_token_for_early_termination=True,
+                              just_score=False):
     """Run inference and post-process outputs, i.e., detokenize,
     move to cpu and convert to list."""

@@ -53,7 +54,8 @@ def generate_and_post_process(model,
         top_p_sampling=top_p_sampling,
         temperature=temperature,
         add_BOS=add_BOS,
-        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
+        just_score=just_score)

     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
@@ -83,7 +85,8 @@ def generate(model,
              top_p_sampling=0.0,
              temperature=1.0,
              add_BOS=False,
-             use_eod_token_for_early_termination=True):
+             use_eod_token_for_early_termination=True,
+             just_score=False):
     """Given prompts and input parameters, run inference and return:
        tokens: prompts plus the generated tokens.
        lengths: length of the prompt + generations. Note that we can
@@ -97,8 +100,8 @@ def generate(model,
     values = [tokens_to_generate,
               return_output_log_probs, return_all_log_probs,
               greedy_sampling, top_k_sampling, top_p_sampling,
-              temperature, add_BOS, use_eod_token_for_early_termination]
-    values_float_tensor = broadcast_float_list(9, float_list=values)
+              temperature, add_BOS, use_eod_token_for_early_termination, just_score]
+    values_float_tensor = broadcast_float_list(10, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
     return_all_log_probs = bool(values_float_tensor[2].item())
@@ -108,12 +111,13 @@ def generate(model,
     temperature = values_float_tensor[6].item()
     add_BOS = bool(values_float_tensor[7].item())
     use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
+    just_score = bool(values_float_tensor[9].item())

     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcaseted to all ranks.
     if torch.distributed.get_rank() == 0:
         assert prompts is not None
-        assert tokens_to_generate > 0
+        #assert tokens_to_generate > 0
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

@@ -125,4 +129,5 @@ def generate(model,
         return_all_log_probs=return_all_log_probs,
         greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
         temperature=temperature,
-        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
+        just_score=just_score)
+6 −4
@@ -34,7 +34,8 @@ def generate_tokens_probs_and_return_on_first_stage(
         return_all_log_probs=False,
         greedy=False, top_k=0, top_p=0.0,
         temperature=1.0,
-        use_eod_token_for_early_termination=True):
+        use_eod_token_for_early_termination=True,
+        just_score=False):
     """Main token generation function.
     Arguments:
         model: no interleaving is supported.
@@ -109,6 +110,7 @@ def generate_tokens_probs_and_return_on_first_stage(
         generated_sequence_lengths = torch.ones(
                 batch_size, dtype=torch.int64,
                 device=torch.cuda.current_device()) * max_sequence_length
+
     # Whether we have reached a termination id.
     is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                      device=torch.cuda.current_device())
@@ -207,7 +209,7 @@ def generate_tokens_probs_and_return_on_first_stage(
     tokens = tokens[:, :(context_length + 1)]
     if mpu.is_pipeline_last_stage():
         if return_output_log_probs:
-            output_log_probs = output_log_probs[:, :context_length]
+            output_log_probs = output_log_probs[:, :context_length].contiguous()
         if return_all_log_probs:
             all_log_probs = all_log_probs[:, :context_length, :]

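With nothing to generate, the loop above amounts to one forward pass over the prompt, and the sliced output_log_probs is made contiguous so downstream communication operates on a dense tensor. A generic sketch of what a scoring-only pass computes, assuming a causal LM that maps token ids to logits (the standard computation, not Megatron's internal code):

import torch
import torch.nn.functional as F

def score_prompt(model, tokens):
    # tokens: [batch, seq] int64 ids. Returns log P(token_t | tokens_<t)
    # for t = 1..seq-1, the quantity a scoring-only request reports.
    with torch.no_grad():
        logits = model(tokens)                        # [batch, seq, vocab]
    log_probs = F.log_softmax(logits.float(), dim=-1)
    targets = tokens[:, 1:].unsqueeze(-1)             # next-token ids
    return log_probs[:, :-1].gather(-1, targets).squeeze(-1)  # [batch, seq-1]
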
+7 −3
@@ -54,12 +54,15 @@ class MegatronGenerate(Resource):
             return "Maximum number of prompts is 128", 400

         tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
+        just_score = False
         if "tokens_to_generate" in request.get_json():
             tokens_to_generate = request.get_json()["tokens_to_generate"]
             if not isinstance(tokens_to_generate, int):
                 return "tokens_to_generate must be an integer greater than 0"
-            if tokens_to_generate < 1:
-                return "tokens_to_generate must be an integer greater than 0"
+            if tokens_to_generate < 0:
+                return "tokens_to_generate must be an integer greater than or equal to 0"
+            if tokens_to_generate == 0:
+                just_score = True

         logprobs = False
         if "logprobs" in request.get_json():
@@ -113,7 +116,8 @@ class MegatronGenerate(Resource):
                     top_p_sampling=top_p,
                     temperature=temperature,
                     add_BOS=add_BOS,
-                    use_eod_token_for_early_termination=True)
+                    use_eod_token_for_early_termination=True,
+                    just_score=just_score)

         return jsonify({"text": response,
             "segments": response_seg,