Commit 803ae5ee authored by rprenger

cleaning up merge conflicts

parents 9cc286ba 0b0e37f0
+1 −1
@@ -427,7 +427,7 @@ Several downstream tasks are described for both GPT and BERT models below. They

## GPT Text Generation

We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also few optional parameters: `temperature`, `top-k`, `top-p`, and `greedy`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.
We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also a few optional parameters: `temperature`, `top-k`, and `top-p`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.

Once the server is running, you can use `tools/text_generation_cli.py` to query it; it takes one argument, which is the host the server is running on.
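If you prefer to hit the server directly rather than go through the CLI, any HTTP client will do. The sketch below is an assumption about the interface, not a documented contract: it presumes the default port 5000 and a JSON `PUT` to an `/api` endpoint with `prompts` and `tokens_to_generate` fields; check `tools/run_text_generation_server.py` for the actual route and payload.

```python
# Minimal sketch of querying the text generation server with `requests`.
# The endpoint, port, and payload fields below are assumptions; verify them
# against tools/run_text_generation_server.py before relying on this.
import json
import requests

def query_server(prompt, tokens_to_generate=32, host="localhost:5000"):
    payload = {"prompts": [prompt], "tokens_to_generate": tokens_to_generate}
    response = requests.put(f"http://{host}/api",
                            data=json.dumps(payload),
                            headers={"Content-Type": "application/json"})
    response.raise_for_status()
    return response.json()

if __name__ == "__main__":
    print(query_server("The quick brown fox"))
```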

+9 −12
@@ -33,7 +33,6 @@ def generate_and_post_process(model,
                              prompts=None,
                              tokens_to_generate=0,
                              return_output_log_probs=False,
                              greedy_sampling=False,
                              top_k_sampling=0,
                              top_p_sampling=0.0,
                              temperature=1.0,
@@ -49,7 +48,6 @@ def generate_and_post_process(model,
        prompts=prompts,
        tokens_to_generate=tokens_to_generate,
        return_output_log_probs=return_output_log_probs,
        greedy_sampling=greedy_sampling,
        top_k_sampling=top_k_sampling,
        top_p_sampling=top_p_sampling,
        temperature=temperature,
@@ -78,7 +76,6 @@ def generate(model,
             prompts=None,
             tokens_to_generate=0,
             return_output_log_probs=False,
             greedy_sampling=False,
             top_k_sampling=0,
             top_p_sampling=0.0,
             temperature=1.0,
@@ -98,16 +95,15 @@ def generate(model,
              return_output_log_probs,
              greedy_sampling, top_k_sampling, top_p_sampling,
              temperature, add_BOS, use_eod_token_for_early_termination, just_score]
    values_float_tensor = broadcast_float_list(9, float_list=values)
    values_float_tensor = broadcast_float_list(8, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    greedy_sampling = bool(values_float_tensor[2].item())
    top_k_sampling = int(values_float_tensor[3].item())
    top_p_sampling = values_float_tensor[4].item()
    temperature = values_float_tensor[5].item()
    add_BOS = bool(values_float_tensor[6].item())
    use_eod_token_for_early_termination = bool(values_float_tensor[7].item())
    just_score = bool(values_float_tensor[8].item())
    top_k_sampling = int(values_float_tensor[2].item())
    top_p_sampling = values_float_tensor[3].item()
    temperature = values_float_tensor[4].item()
    add_BOS = bool(values_float_tensor[5].item())
    use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
    just_score = bool(values_float_tensor[7].item())

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcast to all ranks.
@@ -126,6 +122,7 @@ def generate(model,
    return generate_tokens_probs_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
        greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
        top_k=top_k_sampling,
        top_p=top_p_sampling,
        temperature=temperature,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
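The index shift above (every `values_float_tensor` entry moving down by one and the packed list shrinking from 9 to 8) falls out of how `generate()` distributes its scalar options: they are packed positionally into a single float tensor, broadcast once, and unpacked by index on every rank. A self-contained sketch of that pattern, with the actual `broadcast_float_list` collective left out:

```python
# Sketch of packing scalar options into one float tensor so a single
# broadcast can ship them to every rank, then unpacking them by position.
# Dropping one field (the old greedy flag) shifts every later index by one.
import torch

def pack_options(tokens_to_generate, return_output_log_probs,
                 top_k, top_p, temperature, add_BOS,
                 use_eod_token_for_early_termination, just_score):
    # Everything is cast to float so one fixed-size tensor can be broadcast
    # with a single collective call (omitted here).
    return torch.tensor([tokens_to_generate, return_output_log_probs,
                         top_k, top_p, temperature, add_BOS,
                         use_eod_token_for_early_termination, just_score],
                        dtype=torch.float32)

def unpack_options(values):
    # Positional unpacking must mirror pack_options exactly; each field's
    # index is fixed by its place in the packed list.
    return dict(
        tokens_to_generate=int(values[0].item()),
        return_output_log_probs=bool(values[1].item()),
        top_k=int(values[2].item()),
        top_p=values[3].item(),
        temperature=values[4].item(),
        add_BOS=bool(values[5].item()),
        use_eod_token_for_early_termination=bool(values[6].item()),
        just_score=bool(values[7].item()))
```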
+35 −13
@@ -55,13 +55,31 @@ def send_to_next_pipeline_rank(tensor=None):



def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from last pipeline stage to all ranks."""

    if mpu.is_pipeline_last_stage():
def _is_cuda(tensor):
    """Check if a tensor is not none and is cuda."""
    assert tensor is not None
    assert tensor.is_cuda



def _is_cuda_contiguous(tensor):
    """Check if a tensor is not none, is cuda, and is contiguous."""
    _is_cuda(tensor)
    assert tensor.is_contiguous()



def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from last pipeline stage to all ranks."""

    is_last_stage = mpu.is_pipeline_last_stage()
    # If first stage and last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if mpu.is_pipeline_first_stage() and is_last_stage:
        return tensor

    if is_last_stage:
        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
@@ -78,14 +96,16 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from last stage into the first stage."""

    # Only first and last stage pipeline stages need to be involved.
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If first stage and last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return tensor
    # Only first and last stage pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        if is_last_stage:
            assert tensor is not None
            assert tensor.is_cuda
            assert tensor.is_contiguous()
            _is_cuda_contiguous(tensor)
        else:
            tensor = torch.empty(size,
                                 dtype=dtype,
@@ -105,12 +125,15 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from last stage into the first stage.
    Note that the input tensor is updated in place."""

    # Only first and last stage pipeline stages need to be involved.
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If first stage and last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return
    # Only first and last stage pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        assert tensor is not None
        assert tensor.is_cuda
        _is_cuda(tensor)
        is_contiguous = tensor.is_contiguous()
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
@@ -137,8 +160,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
    """

    if torch.distributed.get_rank() == rank:
        assert tensor is not None
        assert tensor.is_cuda
        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
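The repeated `is not None` / `is_cuda` / `is_contiguous` assertions are now factored into `_is_cuda` and `_is_cuda_contiguous`, and each broadcast helper returns early when the first and last pipeline stages coincide, since without pipeline parallelism there is nothing to communicate. A stripped-down sketch of that control flow, with the `mpu` queries and the `torch.distributed` collective replaced by injected callables so it stands alone:

```python
# Standalone sketch of the early-return-plus-validation pattern used by the
# broadcast helpers. The real code queries mpu for stage membership and calls
# torch.distributed.broadcast; both are passed in here so this runs anywhere.
import torch

def _is_cuda_contiguous(tensor):
    # Sender-side buffers must exist, live on the GPU, and be contiguous.
    assert tensor is not None
    assert tensor.is_cuda
    assert tensor.is_contiguous()

def broadcast_from_last_stage(size, dtype, tensor,
                              is_first_stage, is_last_stage, broadcast_fn):
    # Same rank is both first and last stage: no pipeline parallelism,
    # so there is nothing to send and the tensor is returned as-is.
    if is_first_stage and is_last_stage:
        return tensor
    if is_last_stage:
        _is_cuda_contiguous(tensor)              # sender validates its buffer
    else:
        tensor = torch.empty(size, dtype=dtype)  # receivers allocate one
    broadcast_fn(tensor)                         # e.g. torch.distributed.broadcast
    return tensor
```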
+14 −23
@@ -94,7 +94,7 @@ def score_and_return_on_first_stage(model, tokens, lengths):
def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
        greedy=False, top_k=0, top_p=0.0,
        top_k=0, top_p=0.0,
        temperature=1.0,
        use_eod_token_for_early_termination=True):
    """Main token generation function.
@@ -104,12 +104,12 @@ def generate_tokens_probs_and_return_on_first_stage(
        lengths: original prompt length, size: [b]
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            after logits are modifed for sampling.
        greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
            Note that these three paramters are exclusive meaning that:
                if greedy = true then we should have top-k=top-p=0.
                if top-k > 0 then we expect greedy=false and top-p=0.
                if top-p > 0 then we check for greedy=false and top-k=0.
            from the original logits.
        top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is greedy. Also, these parameters are
            exclusive meaning that:
                if top-k > 0 then we expect top-p=0.
                if top-p > 0 then we check for top-k=0.
        temperature: sampling temperature.
        use_eod_token_for_early_termination: if True, do early termination if
            all the sequences have reached this token.
@@ -148,8 +148,6 @@ def generate_tokens_probs_and_return_on_first_stage(
    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Log probability of all tokens for the sequence.
    
    # Lengths of generated sequences including prompts.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
@@ -190,22 +188,15 @@ def generate_tokens_probs_and_return_on_first_stage(

                # Sample.
                last_token_logits = logits[:, -1, :]
                new_sample, updated_last_token_logits = sample(
                    last_token_logits,
                    greedy=greedy,
                new_sample = sample(last_token_logits,
                                    top_k=top_k,
                                    top_p=top_p,
                                    temperature=temperature,
                                    vocab_size=tokenizer.vocab_size)
                # Now that we have the sample and updated logits,
                # update the main logits and input tokens.
                # If a prompt length is smaller than or equal to the current
                # context length, it means we have started generating tokens.
                started = lengths <= context_length
                # Update the logits
                last_token_logits.masked_scatter_(
                    started.unsqueeze(1), updated_last_token_logits[started])
                # and the tokens.
                # Update the tokens.
                tokens[started, context_length] = new_sample[started]

                # Calculate the log probabilities.
@@ -255,7 +246,7 @@ def generate_tokens_probs_and_return_on_first_stage(
    tokens = tokens[:, :(context_length + 1)]
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length].contiguous()
            output_log_probs = output_log_probs[:, :context_length]

    # ======================================
    # Broadcast to the first pipeline stage.
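The generation loop above now gates the token write with `started = lengths <= context_length`: a sequence only receives a sampled token once its prompt has been fully consumed, while longer prompts keep their own token at the current position. A toy, CPU-only illustration of that masking (values made up for the example):

```python
import torch

# Toy illustration of the `started` mask used in the generation loop:
# only sequences whose prompt is already exhausted receive a sampled token.
batch_size, max_len = 3, 8
tokens = torch.zeros(batch_size, max_len, dtype=torch.long)
lengths = torch.tensor([2, 4, 6])         # original prompt lengths
context_length = 4                        # position currently being filled
new_sample = torch.tensor([11, 22, 33])   # freshly sampled token per sequence

# A sequence has "started" generating once its prompt fits inside the
# current context; longer prompts keep their own token at this position.
started = lengths <= context_length       # tensor([True, True, False])
tokens[started, context_length] = new_sample[started]
print(tokens[:, context_length])          # tensor([11, 22,  0])
```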
+8 −9
@@ -55,8 +55,7 @@ def modify_logits_for_top_p_filtering(logits, top_p):



def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
           vocab_size=None):
def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
    """ Sample and generate a token.
    Note: logits has the dimension [b, v] where b is the batch size
          and v is the vocabulary size.
@@ -70,21 +69,21 @@ def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
    assert logits.type() == 'torch.cuda.FloatTensor', \
        'input logits should be floats.'

    # Clone so we do not modify the inputs,
    logits = logits.clone()

    # Greedy is just simple argmax.
    if greedy:
        assert top_k == 0, 'cannot set both greedy and top-k samplings.'
    if top_k == 1:
        assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
        samples = torch.argmax(logits, dim=-1)

    # Top-k or top-p sampling.
    else:
        # Clone so we do not modify the inputs.
        logits = logits.clone()
        # Apply temperature in place.
        if temperature != 1.0:
            logits.div_(temperature)

        if top_k > 0:
        if top_k > 1:
            assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
            assert top_k <= logits.size(1), 'top-k is larger than logit size.'
            if vocab_size:
@@ -104,4 +103,4 @@ def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
    if vocab_size:
        samples = torch.clamp(samples, min=0, max=(vocab_size - 1))

    return samples, logits
    return samples
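With the `greedy` flag removed, greedy decoding is requested as `top_k=1`, and top-k and top-p remain mutually exclusive. The following is a simplified stand-in for the new `sample()` interface rather than the Megatron implementation itself (it skips the CUDA assertion, vocab clamping, and top-p filtering) to show how callers select a mode now:

```python
import torch

def sample_sketch(logits, top_k=0, top_p=0.0, temperature=1.0):
    # Simplified stand-in for the updated sample() interface: greedy
    # decoding is spelled top_k=1, and top-k / top-p remain exclusive.
    # Top-p (nucleus) filtering is omitted here for brevity.
    logits = logits.clone()
    if top_k == 1:
        assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
        return torch.argmax(logits, dim=-1)
    if temperature != 1.0:
        logits.div_(temperature)
    if top_k > 1:
        assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
        top_values, _ = torch.topk(logits, top_k, dim=-1)
        # Mask out everything below the k-th largest logit per row.
        logits[logits < top_values[..., -1, None]] = float('-inf')
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

# Greedy decoding under the new interface is just top_k=1.
logits = torch.tensor([[0.1, 2.0, -1.0], [1.5, 0.0, 0.3]])
print(sample_sketch(logits, top_k=1))                   # argmax per row
print(sample_sketch(logits, top_k=2, temperature=0.8))  # top-k sampling
```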