Commit a7539b0f authored by mshoeybi

pipelining works
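
The inference forward pass is now wrapped in a single ForwardStep object that owns the InferenceParams (max_batch_size / max_sequence_len plus the sequence_len_offset and batch_size_offset counters used by the attention key-value cache) and, per call, either runs one plain forward step or slices the batch into micro-batches and pipelines them across stages. A rough usage sketch from the caller's side; the model, tokens, position_ids and attention_mask objects below are assumed to come from the existing generation setup and are not part of this commit:

    # Illustrative driver code, not part of the diff below.
    forward_step = ForwardStep(model, max_batch_size=8, max_sequence_len=1024)

    # If tokens.size(0) * tokens.size(1) >= 512 (self.constant), ForwardStep
    # dispatches to _with_pipelining_forward_step and splits the batch into
    # micro-batches; otherwise it runs _no_pipelining_forward_step.
    logits = forward_step(tokens, position_ids, attention_mask)

    # Only the last pipeline stage gets logits back; other stages receive None.
    # The key-value cache written by ParallelAttention is addressed through
    # inference_params.sequence_len_offset and batch_size_offset.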

parent 8f160844
+126 −122
@@ -15,8 +15,6 @@
 
 """Forward step utilities."""
 
-from abc import ABC
-from abc import abstractmethod
 from collections.abc import Iterable
 
 import torch
@@ -24,44 +22,27 @@ import torch
 from megatron import (
     get_args,
     mpu)
-from megatron.p2p_communication import (
-    recv_forward,
-    send_forward)
-
-
-def forward_step_provider(model,
-                          batch_size,
-                          micro_batch_size,
-                          max_sequence_len):
-
-    args = get_args()
-
-    if args.pipeline_model_parallel_size == 1 or micro_batch_size >= batch_size:
-        return NoPipeliningForwardStep(model, batch_size, max_sequence_len)
-
-    return SimplePipeliningForwardStep(model, batch_size,
-                                       micro_batch_size,
-                                       max_sequence_len)
-
 
 
 class InferenceParams:
 
-    def __init__(self, micro_batch_size_list, max_sequence_len):
-
-        assert isinstance(micro_batch_size_list, list)
-        assert max_sequence_len > 0
+    def __init__(self, max_batch_size, max_sequence_len):
 
-        self.micro_batch_size_list = micro_batch_size_list
         self.max_sequence_len = max_sequence_len
+        self.max_batch_size = max_batch_size
+        self.sequence_len_offset = 0
+        self.batch_size_offset = 0
         self.allocate_key_value_memory = True
-        self.micro_batch_index = 0
 
 
-class ForwardStepBase(ABC):
-
-    def __init__(self, model):
+class ForwardStep:
+
+    def __init__(self, model, max_batch_size, max_sequence_len):
 
         # Make sure model is in eval mode.
         if isinstance(model, Iterable):
             for this_model in model:
                 this_model.eval()
@@ -69,125 +50,148 @@ class ForwardStepBase(ABC):
             model.eval()
         self.model = model
 
-    @abstractmethod
-    def __call__(self, tokens, position_ids, attention_mask):
-        pass
+        self.constant = 512
+
+        # Initialize inference parameters.
+        self.inference_params = InferenceParams(max_batch_size,
+                                                max_sequence_len)
 
 
-class SimplePipeliningForwardStep(ForwardStepBase):
+    def __call__(self, tokens, position_ids, attention_mask):
+        if tokens.size(0) * tokens.size(1) >= self.constant:
+            micro_batch_size = max(1, self.constant // tokens.size(1))
+            return _with_pipelining_forward_step(self.model, tokens,
+                                                 position_ids,
+                                                 attention_mask,
+                                                 self.inference_params,
+                                                 micro_batch_size)
+        else:
+            return _no_pipelining_forward_step(self.model, tokens,
+                                               position_ids,
+                                               attention_mask,
+                                               self.inference_params)
 
-    def __init__(self, model, batch_size, micro_batch_size, max_sequence_len):
-        super().__init__(model)
-
-        self.batch_size = batch_size
-        # Divide the batch dimension into micro batches.
-        self.num_micro_batches, last_chunk = divmod(batch_size,
-                                                    micro_batch_size)
-        self.micro_batch_size_list = []
-        self.batch_dim_start_index = [0]
-        for i in range(self.num_micro_batches):
-            self.micro_batch_size_list.append(micro_batch_size)
-            self.batch_dim_start_index.append((i + 1) * micro_batch_size)
-        if last_chunk > 0:
-            self.num_micro_batches += 1
-            self.micro_batch_size_list.append(last_chunk)
-            self.batch_dim_start_index.append(batch_size)
-
-        self.inference_params = InferenceParams(self.micro_batch_size_list,
-                                                max_sequence_len)
+
+def _get_recv_buffer_dtype(args):
+    """Receive happens between the layers."""
+    if args.fp32_residual_connection:
+        return torch.float
+    return args.params_dtype
 
 
-    def __call__(self, tokens, position_ids, attention_mask):
-
-        # Need to tell p2p_communicate functions the correct size.
-        args = get_args()
-        orig_seq_length = args.seq_length
-        args.seq_length = tokens.size(1)
-        assert args.seq_length <= self.inference_params.max_sequence_len
-
-        # Preallocate memory for output logits.
-        logits = None
-        if mpu.is_pipeline_last_stage():
-            logits = torch.empty(tokens.size(0),
-                                 tokens.size(1),
-                                 args.padded_vocab_size,
-                                 dtype=torch.float32,
-                                 device=torch.cuda.current_device())
+def _allocate_recv_buffer(batch_size, sequence_length):
+    """Receive happens between the layers with size [s, b, h]."""
+    if mpu.is_pipeline_first_stage():
+        return None
+    args = get_args()
+    recv_size = (sequence_length, batch_size, args.hidden_size)
+    return torch.empty(recv_size,
+                       dtype=_get_recv_buffer_dtype(args),
+                       device=torch.cuda.current_device())
 
-        # Pileline using micro batches.
-        for micro_batch_index in range(self.num_micro_batches):
-            # Set micro-batch size and index.
-            self.inference_params.micro_batch_index = micro_batch_index
-            args.micro_batch_size = self.micro_batch_size_list[
-                micro_batch_index]
-            # Slice among the batch dimenion.
-            start = self.batch_dim_start_index[micro_batch_index]
-            end = self.batch_dim_start_index[micro_batch_index + 1]
-            tokens2use = tokens[start:end, ...]
-            position_ids2use = position_ids[start:end, ...]
 
-            input_tensor = recv_forward()
+def _forward_step_helper(model, tokens, position_ids, attention_mask,
+                         inference_params, recv_buffer=None):
+    """Single forward step. Update the allocate memory flag so
+    only the first time the memory is allocated."""
+    batch_size = tokens.size(0)
+    sequence_length = tokens.size(1)
+    if recv_buffer is None:
+        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
 
-            self.model.set_input_tensor(input_tensor)
-            output_tensor = self.model(tokens2use, position_ids2use,
-                                       attention_mask,
-                                       inference_params=self.inference_params)
+    # Receive from previous stage.
+    if not mpu.is_pipeline_first_stage():
+        torch.distributed.recv(recv_buffer,
+                               src=mpu.get_pipeline_model_parallel_prev_rank())
 
-            send_forward(output_tensor)
+    # Forward pass through the model.
+    model.set_input_tensor(recv_buffer)
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          inference_params=inference_params)
 
-            # Reset the sequence lenght to whatwever it was before.
-            if self.inference_params.allocate_key_value_memory:
-                self.inference_params.allocate_key_value_memory = False
+    # Send output to the next stage.
+    if not mpu.is_pipeline_last_stage():
+        torch.distributed.send(output_tensor,
+                               mpu.get_pipeline_model_parallel_next_rank())
 
-            if mpu.is_pipeline_last_stage():
-                logits[start:end, ...] = output_tensor
-
-        # Adjust the sequence length back to whatever it was before.
-        args.seq_length = orig_seq_length
+    # Make sure we do not allocate context memory anymore.
+    if inference_params.allocate_key_value_memory:
+        inference_params.allocate_key_value_memory = False
 
-        return logits
+    return output_tensor
 
 
-class NoPipeliningForwardStep(ForwardStepBase):
+def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
+                                inference_params, recv_buffer=None):
 
-    def __init__(self, model, batch_size, max_sequence_len):
-        super().__init__(model)
-
-        self.inference_params = InferenceParams([batch_size], max_sequence_len)
+    # Run a simple forward pass.
+    output_tensor = _forward_step_helper(model, tokens, position_ids,
+                                         attention_mask, inference_params,
+                                         recv_buffer=recv_buffer)
+    # Update the sequence length offset.
+    inference_params.sequence_len_offset += tokens.size(1)
 
+    logits = None
+    if mpu.is_pipeline_last_stage():
+        logits = output_tensor
 
-    def __call__(self, tokens, position_ids, attention_mask):
+    return logits
 
-        # Need to tell p2p_communicate functions the correct size.
-        args = get_args()
-        orig_seq_length = args.seq_length
-        args.seq_length = tokens.shape[1]
-        assert args.seq_length <= self.inference_params.max_sequence_len
-        args.micro_batch_size = tokens.shape[0]
-        assert self.inference_params.micro_batch_size_list[0] == tokens.shape[0]
-        assert self.inference_params.micro_batch_index == 0
+
+def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
+                                  inference_params, micro_batch_size):
+
+    sequence_length = tokens.size(1)
+    batch_size = tokens.size(0)
+
+    # Divide the batch dimension into micro batches.
+    num_micro_batches, last_chunk = divmod(batch_size,
+                                           micro_batch_size)
+    if last_chunk > 0:
+        num_micro_batches += 1
 
-        # Receive from previous stage.
-        input_tensor = recv_forward()
+    # Preallocate memory for output logits.
+    logits = None
+    if mpu.is_pipeline_last_stage():
+        args = get_args()
+        logits = torch.empty(
+            (batch_size, sequence_length, args.padded_vocab_size),
+            dtype=torch.float32, device=torch.cuda.current_device())
 
-        # Forward pass through the model.
-        self.model.set_input_tensor(input_tensor)
-        output_tensor = self.model(tokens, position_ids, attention_mask,
-                                   inference_params=self.inference_params)
+    # Preallocate recv buffer.
+    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
 
-        # Send output to the next stage.
-        send_forward(output_tensor)
+    for micro_batch_index in range(num_micro_batches):
+        # Slice among the batch dimenion.
+        start = micro_batch_index * micro_batch_size
+        end = min(start + micro_batch_size, batch_size)
+        this_micro_batch_size = end - start
+        tokens2use = tokens[start:end, ...]
+        position_ids2use = position_ids[start:end, ...]
 
-        # Reset the sequence lenght to whatwever it was before.
-        args.seq_length = orig_seq_length
-        # Make sure we do not allocate context memory anymore.
-        if self.inference_params.allocate_key_value_memory:
-            self.inference_params.allocate_key_value_memory = False
+        # Run a simple forward pass.
+        if this_micro_batch_size != micro_batch_size:
+            recv_buffer = None
+        output = _forward_step_helper(model, tokens2use, position_ids2use,
+                                      attention_mask, inference_params,
+                                      recv_buffer=recv_buffer)
 
-        return output_tensor
+        # Adjust the batch size offset to account for the micro-batch.
+        inference_params.batch_size_offset += this_micro_batch_size
 
+        # Copy logits.
+        if mpu.is_pipeline_last_stage():
+            logits[start:end, ...] = output
 
+    # Once we are done with all the micro-batches, we can
+    # adjust the sequence length offset.
+    inference_params.sequence_len_offset += sequence_length
+    # and reset the batch size offset
+    inference_params.batch_size_offset = 0
+
+    return logits
+4 −5
@@ -24,7 +24,7 @@ from .communication import (
     copy_from_last_to_first_pipeline_stage,
     broadcast_from_last_pipeline_stage,
     broadcast_from_last_to_first_pipeline_stage)
-from .forward_step import forward_step_provider
+from .forward_step import ForwardStep
 from .sampling import sample
 

@@ -66,8 +66,7 @@ def generate_tokens_probs_and_return_on_first_stage(
     max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
 
     # forward step.
-    forward_step = forward_step_provider(model, batch_size, 4,
-                                         max_sequence_length)
+    forward_step = ForwardStep(model, batch_size, max_sequence_length)
 
     # Added termination_id to support the case that we want to terminate the
     # generation once that id is generated.
@@ -190,8 +189,8 @@ def generate_tokens_probs_and_return_on_first_stage(
                 done = torch.all(is_generation_done)
             done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                       tensor=done)
-            if done:
-                break
+            #if done:
+            #    break
 
     # ===================================================
     # Update the length of based on max generated length.
+25 −31
@@ -180,9 +180,8 @@ class ParallelAttention(MegatronModule):
             skip_bias_add=True)
 
         # Inference key-value memory
-        self.inference_key_memory_list = None
-        self.inference_value_memory_list = None
-        self.inference_current_sequence_len_list = None
+        self.inference_key_memory = None
+        self.inference_value_memory = None
 
 
     def _allocate_memory(self, inference_max_sequence_len, batch_size):
@@ -206,22 +205,17 @@
         if inference_params:
             if inference_params.allocate_key_value_memory:
                 inf_max_seq_len = inference_params.max_sequence_len
-                inf_batch_sizes = inference_params.micro_batch_size_list
-                self.inference_key_memory_list = [
-                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
-                    for inf_batch_size in inf_batch_sizes]
-                self.inference_value_memory_list = [
-                    self._allocate_memory(inf_max_seq_len, inf_batch_size)
-                    for inf_batch_size in inf_batch_sizes]
-                self.inference_current_sequence_len_list = [
-                    0 for _ in inf_batch_sizes]
+                inf_max_batch_size = inference_params.max_batch_size
+                self.inference_key_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                self.inference_value_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
         # This is added for safety. In case inference_params
         # is not provided, make sure there is no potential memory left
         # from previous inference.
         else:
-            self.inference_key_memory_list = None
-            self.inference_value_memory_list = None
-            self.inference_current_sequence_len_list = None
+            self.inference_value_memory = None
+            self.inference_current_sequence_len = None
 
         # =====================
         # Query, Key, and Value
@@ -269,23 +263,23 @@
         # ==================================
 
         if inference_params:
-            inf_batch_index = inference_params.micro_batch_index
-            assert key_layer.size(1) == \
-                inference_params.micro_batch_size_list[inf_batch_index]
             # Adjust the range variables.
-            start = self.inference_current_sequence_len_list[inf_batch_index]
-            end = start + key_layer.size(0)
-            assert end <= inference_params.max_sequence_len
-            self.inference_current_sequence_len_list[inf_batch_index] = end
+            batch_start = inference_params.batch_size_offset
+            batch_end = batch_start + key_layer.size(1)
+            assert batch_end <= self.inference_key_memory.size(1)
+            sequence_start = inference_params.sequence_len_offset
+            sequence_end = sequence_start + key_layer.size(0)
+            assert sequence_end <= self.inference_key_memory.size(0)
             # Copy key and values.
-            self.inference_key_memory_list[inf_batch_index][start:end, ...] \
-                = key_layer
-            self.inference_value_memory_list[inf_batch_index][start:end, ...] \
-                = value_layer
-            key_layer = \
-                self.inference_key_memory_list[inf_batch_index][:end, ...]
-            value_layer = \
-                self.inference_value_memory_list[inf_batch_index][:end, ...]
+            self.inference_key_memory[sequence_start:sequence_end,
+                                      batch_start:batch_end,
+                                      ...] = key_layer
+            self.inference_value_memory[sequence_start:sequence_end,
+                                        batch_start:batch_end,
+                                        ...] = value_layer
+            key_layer = self.inference_key_memory[
+                :sequence_end, batch_start:batch_end, ...]
+            value_layer = self.inference_value_memory[
+                :sequence_end, batch_start:batch_end, ...]
 
 
         # ===================================