Commit 8c119d80 authored by mshoeybi

tested and working

parent a7539b0f
+13 −0
@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
    parser = _add_biencoder_args(parser)
    parser = _add_vit_args(parser)
    parser = _add_logging_args(parser)
    parser = _add_inference_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
@@ -273,6 +274,18 @@ def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


def _add_inference_args(parser):
    group = parser.add_argument_group(title='inference')

    group.add_argument('--inference-batch-times-seqlen-threshold',
                       type=int, default=512,
                       help='During inference, if batch-size times '
                       'sequence-length is smaller than this threshold '
                       'then we will not use pipelining, otherwise we will.')

    return parser

    
def _add_network_size_args(parser):
    group = parser.add_argument_group(title='network size')

+41 −11
@@ -26,14 +26,20 @@ from .tokenization import (
    detokenize_generations)



def generate_and_post_process(model,
                              prompts=None,
                              tokens_to_generate=0,
                              return_output_log_probs=False,
                              return_all_log_probs=False,
                              greedy_sampling=False,
                              top_k_sampling=0,
                              top_p_sampling=0.0,
                              temperature=1.0,
                              add_BOS=False):
    """TO DO ..."""
                              add_BOS=False,
                              use_eod_token_for_early_termination=True):
    """Run inferecne and post-process outputs, i.e., detokenize,
    move to cpu and convert to list."""

    # Main inference.
    tokens, lengths, output_log_probs, all_log_probs = generate(
@@ -42,8 +48,12 @@ def generate_and_post_process(model,
        tokens_to_generate=tokens_to_generate,
        return_output_log_probs=return_output_log_probs,
        return_all_log_probs=return_all_log_probs,
        greedy_sampling=greedy_sampling,
        top_k_sampling=top_k_sampling,
        top_p_sampling=top_p_sampling,
        temperature=temperature,
        add_BOS=add_BOS)
        add_BOS=add_BOS,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination)

    # Only post-process on first stage.
    if mpu.is_pipeline_first_stage():
@@ -62,24 +72,42 @@ def generate_and_post_process(model,
    return None



def generate(model,
             prompts=None,
             tokens_to_generate=0,
             return_output_log_probs=False,
             return_all_log_probs=False,
             greedy_sampling=False,
             top_k_sampling=0,
             top_p_sampling=0.0,
             temperature=1.0,
             add_BOS=False):
    """TO DO ..."""
             add_BOS=False,
             use_eod_token_for_early_termination=True):
    """Given prompts and input parameters, run inference and return:
       tokens: prompts plus the generated tokens.
       lengths: length of the prompt + generations. Note that we can
           discard tokens in the tokens tensor that are after the
           corresponding length.
       output_log_probs: log probs of the tokens.
       all_log_probs: full log probs for all of the tokens.
    """

    # Make sure input params are available to all ranks.
    values = [tokens_to_generate, return_output_log_probs,
              return_all_log_probs, temperature, add_BOS]
    values_float_tensor = broadcast_float_list(5, float_list=values)
    values = [tokens_to_generate,
              return_output_log_probs, return_all_log_probs,
              greedy_sampling, top_k_sampling, top_p_sampling,
              temperature, add_BOS, use_eod_token_for_early_termination]
    values_float_tensor = broadcast_float_list(9, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    return_all_log_probs = bool(values_float_tensor[2].item())
    temperature = values_float_tensor[3].item()
    add_BOS = bool(values_float_tensor[4].item())
    greedy_sampling = bool(values_float_tensor[3].item())
    top_k_sampling = int(values_float_tensor[4].item())
    top_p_sampling = values_float_tensor[5].item()
    temperature = values_float_tensor[6].item()
    add_BOS = bool(values_float_tensor[7].item())
    use_eod_token_for_early_termination = bool(values_float_tensor[8].item())

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcast to all ranks.
@@ -95,4 +123,6 @@ def generate(model,
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
        return_all_log_probs=return_all_log_probs,
        temperature=temperature)
        greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
        temperature=temperature,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
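
Since broadcast_float_list sends a single float tensor, the scalar inference parameters of mixed types (ints, bools, floats) are packed as floats and cast back after the collective. A minimal standalone sketch of that pack/unpack round trip (torch.distributed is omitted here; a plain tensor stands in for the broadcast, and the values are hypothetical):

import torch

# Same ordering as the values list in generate(); the values are hypothetical.
values = [10, True, False, False, 0, 0.5, 1.0, False, True]
values_float_tensor = torch.tensor(values, dtype=torch.float32)  # stand-in for broadcast_float_list(9, ...)

# Cast back with the same conversions used after the broadcast.
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
top_p_sampling = values_float_tensor[5].item()
use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
assert tokens_to_generate == 10 and return_output_log_probs
assert top_p_sampling == 0.5 and use_eod_token_for_early_termination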
+36 −0
@@ -21,6 +21,38 @@ import torch
from megatron import mpu



def recv_from_prev_pipeline_rank_(recv_buffer=None):
    """Receive from previous pipeline stage and update the
    input buffer inplace."""
    if not mpu.is_pipeline_first_stage():
        assert recv_buffer is not None
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv, recv_buffer,
            mpu.get_pipeline_model_parallel_prev_rank())
        reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()



def send_to_next_pipeline_rank(tensor=None):
    """Send output to the next pipeline stage."""
    if not mpu.is_pipeline_last_stage():
        assert tensor is not None
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor,
            mpu.get_pipeline_model_parallel_next_rank())
        reqs = torch.distributed.batch_isend_irecv([send_next_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()



def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from last pipeline stage to all ranks."""

@@ -96,6 +128,7 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
            tensor[...] = tensor_



def broadcast_tensor(size, dtype, tensor=None, rank=0):
    """ Given size and type of a tensor on all ranks and the tensor value
        only on a specific rank, broadcast from that rank to all other ranks.
@@ -114,6 +147,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
    return tensor



def broadcast_list(size, dtype, list_values=None, rank=0):
    """Broadcast a list of values with a given type."""

@@ -125,12 +159,14 @@ def broadcast_list(size, dtype, list_values=None, rank=0):
    return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)



def broadcast_int_list(size, int_list=None, rank=0):
    """Broadcast a list of interger values."""

    return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)



def broadcast_float_list(size, float_list=None, rank=0):
    """Broadcast a list of float values."""

+47 −32
@@ -22,14 +22,20 @@ import torch
from megatron import (
    get_args,
    mpu)
from .communication import (
    send_to_next_pipeline_rank,
    recv_from_prev_pipeline_rank_)



class InferenceParams:

    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_len):

        """Note that offsets are set to zero and we always set the
        flag to allocate memory. After the first call, make sure to
        set this flag to False."""
        self.max_sequence_len = max_sequence_len
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
@@ -39,34 +45,46 @@ class InferenceParams:


class ForwardStep:
    """Forward step function with all the communications.
    We use a class here to hide the inference parameters
    from the outside caller."""

    def __init__(self, model, max_batch_size, max_sequence_len):

        """Set values so we don't need to do it multiple times."""
        # Make sure model is in eval mode.
        if isinstance(model, Iterable):
            for this_model in model:
                this_model.eval()
        else:
        assert not isinstance(model, Iterable), \
            'interleaving schedule is not supported for inference'
        model.eval()
        self.model = model

        self.constant = 512

        # Initialize inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_len)
        # Pipelining arguments.
        args = get_args()
        self.pipeline_size_larger_than_one = (
            args.pipeline_model_parallel_size > 1)
        # Threshold of pipelining.
        self.pipelining_batch_x_seqlen = \
            args.inference_batch_times_seqlen_threshold


    def __call__(self, tokens, position_ids, attention_mask):
        if tokens.size(0) * tokens.size(1) >= self.constant:
            micro_batch_size = max(1, self.constant // tokens.size(1))
            return _with_pipelining_forward_step(self.model, tokens,
        """Invocation of the forward methods. Note that self.inference_params
        is being modified by the forward step."""
        # Pipelining case.
        if self.pipeline_size_larger_than_one:
            current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
                micro_batch_size = \
                    max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
                return _with_pipelining_forward_step(self.model,
                                                     tokens,
                                                     position_ids,
                                                     attention_mask,
                                                     self.inference_params,
                                                     micro_batch_size)
        else:
            return _no_pipelining_forward_step(self.model, tokens,

        return _no_pipelining_forward_step(self.model,
                                           tokens,
                                           position_ids,
                                           attention_mask,
                                           self.inference_params)
@@ -103,9 +121,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)

    # Receive from previous stage.
    if not mpu.is_pipeline_first_stage():
        torch.distributed.recv(recv_buffer,
                               src=mpu.get_pipeline_model_parallel_prev_rank())
    recv_from_prev_pipeline_rank_(recv_buffer)

    # Forward pass through the model.
    model.set_input_tensor(recv_buffer)
@@ -113,9 +129,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
                          inference_params=inference_params)

    # Send output to the next stage.
    if not mpu.is_pipeline_last_stage():
        torch.distributed.send(output_tensor,
                               mpu.get_pipeline_model_parallel_next_rank())
    send_to_next_pipeline_rank(output_tensor)

    # Make sure we do not allocate context memory anymore.
    if inference_params.allocate_key_value_memory:
@@ -128,7 +142,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,

def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                inference_params, recv_buffer=None):

    """If recv_buffer is none, we will allocate one on the fly."""
    # Run a simple forward pass.
    output_tensor = _forward_step_helper(model, tokens, position_ids,
                                         attention_mask, inference_params,
@@ -143,9 +157,10 @@ def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
    return logits



def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                  inference_params, micro_batch_size):

    """No interleaving is supported."""
    sequence_length = tokens.size(1)
    batch_size = tokens.size(0)
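
The micro-batch size above is chosen so that micro_batch_size * sequence_length stays at or below the pipelining threshold. An illustrative, standalone sketch of the batch split this implies (the chunking helper below is hypothetical and not the commit's code):

import torch

def split_micro_batches(tokens, micro_batch_size):
    # Hypothetical helper, not the commit's code: iterate over the batch
    # dimension in chunks of micro_batch_size, which is how a pipelined
    # forward step processes a large batch piece by piece.
    batch_size = tokens.size(0)
    for start in range(0, batch_size, micro_batch_size):
        yield tokens[start:start + micro_batch_size]

tokens = torch.zeros(8, 1024, dtype=torch.long)          # hypothetical [b, s]
threshold = 512                                          # inference-batch-times-seqlen-threshold
micro_batch_size = max(1, threshold // tokens.size(1))   # 512 // 1024 == 0, so clamp to 1
chunks = list(split_micro_batches(tokens, micro_batch_size))
print(len(chunks), chunks[0].shape)                      # 8 chunks of shape [1, 1024]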

+18 −10
@@ -32,10 +32,12 @@ def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
        return_all_log_probs=False,
        temperature=1.0):
        greedy=False, top_k=0, top_p=0.0,
        temperature=1.0,
        use_eod_token_for_early_termination=True):
    """Main token generation function.
    Arguments:
        model: XXX
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max-sequence-length]
        lengths: original prompt length, size: [b]
        return_output_log_probs: flag to calculate the log probability of
@@ -44,7 +46,14 @@ def generate_tokens_probs_and_return_on_first_stage(
        return_all_log_probs: flag to calculate the log probabilities across
            all the tokens (vocab size). Note that the log probability is the
            one after logits are modified for sampling.
        greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
            Note that these three parameters are mutually exclusive, meaning that:
                if greedy = true then we should have top-k=top-p=0.
                if top-k > 0 then we expect greedy=false and top-p=0.
                if top-p > 0 then we check for greedy=false and top-k=0.
        temperature: sampling temperature.
        use_eod_token_for_early_termination: if True, terminate early once
            all the sequences have reached the EOD token.
    Note: apart from the model, the other parameters only need to be
          available on rank 0.
    Outputs: Note that the size is adjusted to a lower value than
@@ -108,10 +117,9 @@ def generate_tokens_probs_and_return_on_first_stage(
    # Run inference
    # =============

    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(
            tokens)

    with torch.no_grad():
        prev_context_length = 0
        for context_length in range(min_prompt_length, max_sequence_length):

@@ -132,9 +140,9 @@ def generate_tokens_probs_and_return_on_first_stage(
                last_token_logits = logits[:, -1, :]
                new_sample, updated_last_token_logits = sample(
                    last_token_logits,
                    greedy=args.greedy,
                    top_k=args.top_k,
                    top_p=args.top_p,
                    greedy=greedy,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    vocab_size=tokenizer.vocab_size)
                # Now that we have the sample and updated logits,
@@ -189,8 +197,8 @@ def generate_tokens_probs_and_return_on_first_stage(
                done = torch.all(is_generation_done)
            done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                      tensor=done)
            #if done:
            #    break
            if use_eod_token_for_early_termination and done:
                break

    # ===================================================
    # Update the lengths based on the max generated length.
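
For reference, a hedged sketch of what one greedy / top-k / top-p sampling step over the last-token logits can look like; the actual sample() called above is not part of this diff and may differ in its details:

import torch

def sample_sketch(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0):
    # Illustrative only; greedy, top_k > 0 and top_p > 0 are treated as
    # mutually exclusive, matching the docstring in this commit.
    if greedy:
        return torch.argmax(logits, dim=-1)
    logits = logits / temperature
    if top_k > 0:
        # Keep only the k largest logits per row.
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float('-inf'))
    elif top_p > 0.0:
        # Nucleus sampling: keep the smallest prefix of sorted tokens whose
        # cumulative probability covers top_p, then sample in sorted space.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        probs = torch.softmax(sorted_logits, dim=-1)
        keep = (probs.cumsum(dim=-1) - probs) <= top_p
        sorted_logits = sorted_logits.masked_fill(~keep, float('-inf'))
        choice = torch.multinomial(torch.softmax(sorted_logits, dim=-1), 1)
        return sorted_idx.gather(-1, choice).squeeze(-1)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

print(sample_sketch(torch.randn(2, 10), top_k=3))        # two sampled token ids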