Commit 8f160844 authored by mshoeybi's avatar mshoeybi
Browse files

simple pipelining works

parent dee8707e
Loading
Loading
Loading
Loading
+114 −41
Original line number Diff line number Diff line
@@ -15,15 +15,33 @@

"""Forward step utilities."""

from abc import ABC
from abc import abstractmethod
from collections.abc import Iterable
from enum import Enum

from megatron.p2p_communication import recv_forward, send_forward
from megatron import get_args
import torch

from megatron import (
    get_args,
    mpu)
from megatron.p2p_communication import (
    recv_forward,
    send_forward)

class ForwardStepTypes(Enum):
    NO_PIPELINING = 1

def forward_step_provider(model,
                          batch_size,
                          micro_batch_size,
                          max_sequence_len):

    args = get_args()

    if args.pipeline_model_parallel_size == 1 or micro_batch_size >= batch_size:
        return NoPipeliningForwardStep(model, batch_size, max_sequence_len)

    return SimplePipeliningForwardStep(model, batch_size,
                                       micro_batch_size,
                                       max_sequence_len)



@@ -37,12 +55,12 @@ class InferenceParams:
        self.micro_batch_size_list = micro_batch_size_list
        self.max_sequence_len = max_sequence_len
        self.allocate_key_value_memory = True
        self.micro_batch_size_index = 0
        self.micro_batch_index = 0


class InferenceForwardStep:
class ForwardStepBase(ABC):

    def __init__(self, model, batch_size, max_sequence_len):
    def __init__(self, model):

        if isinstance(model, Iterable):
            for this_model in model:
@@ -51,70 +69,125 @@ class InferenceForwardStep:
            model.eval()
        self.model = model

        self.inference_params = InferenceParams([batch_size], max_sequence_len)
        self.forward_step_type = ForwardStepTypes.NO_PIPELINING
    @abstractmethod
    def __call__(self, tokens, position_ids, attention_mask):
        pass


    def __call__(self, tokens, position_ids, attention_mask):

        if self.forward_step_type == ForwardStepTypes.NO_PIPELINING:
            return self._forward_step_no_pipelining(tokens, position_ids,
                                                    attention_mask)
class SimplePipeliningForwardStep(ForwardStepBase):

        raise Exception('unknown forward step type {}'.format(
            self.forward_step_type))
    def __init__(self, model, batch_size, micro_batch_size, max_sequence_len):
        super().__init__(model)

        self.batch_size = batch_size
        # Divide the batch dimension into micro batches.
        self.num_micro_batches, last_chunk = divmod(batch_size,
                                                    micro_batch_size)
        self.micro_batch_size_list = []
        self.batch_dim_start_index = [0]
        for i in range(self.num_micro_batches):
            self.micro_batch_size_list.append(micro_batch_size)
            self.batch_dim_start_index.append((i + 1) * micro_batch_size)
        if last_chunk > 0:
            self.num_micro_batches += 1
            self.micro_batch_size_list.append(last_chunk)
            self.batch_dim_start_index.append(batch_size)

    def _forward_step_no_pipelining(self, tokens, position_ids, attention_mask):
        self.inference_params = InferenceParams(self.micro_batch_size_list,
                                                max_sequence_len)


    def __call__(self, tokens, position_ids, attention_mask):

        # Need to tell p2p_communicate functions the correct size.
        args = get_args()
        orig_seq_length = args.seq_length
        args.seq_length = tokens.shape[1]
        args.seq_length = tokens.size(1)
        assert args.seq_length <= self.inference_params.max_sequence_len
        args.micro_batch_size = tokens.shape[0]
        assert self.inference_params.micro_batch_size_list[0] == tokens.shape[0]
        assert self.inference_params.micro_batch_size_index == 0

        # Preallocate memory for output logits.
        logits = None
        if mpu.is_pipeline_last_stage():
            logits = torch.empty(tokens.size(0),
                                 tokens.size(1),
                                 args.padded_vocab_size,
                                 dtype=torch.float32,
                                 device=torch.cuda.current_device())

        # Pileline using micro batches.
        for micro_batch_index in range(self.num_micro_batches):
            # Set micro-batch size and index.
            self.inference_params.micro_batch_index = micro_batch_index
            args.micro_batch_size = self.micro_batch_size_list[
                micro_batch_index]
            # Slice among the batch dimenion.
            start = self.batch_dim_start_index[micro_batch_index]
            end = self.batch_dim_start_index[micro_batch_index + 1]
            tokens2use = tokens[start:end, ...]
            position_ids2use = position_ids[start:end, ...]

            # Receive from previous stage.
            input_tensor = recv_forward()

            # Forward pass through the model.
            self.model.set_input_tensor(input_tensor)
        output_tensor = self.model(tokens, position_ids, attention_mask,
            output_tensor = self.model(tokens2use, position_ids2use,
                                       attention_mask,
                                       inference_params=self.inference_params)

            # Send output to the next stage.
            send_forward(output_tensor)

            # Reset the sequence lenght to whatwever it was before.
        args.seq_length = orig_seq_length
            # Make sure we do not allocate context memory anymore.
            if self.inference_params.allocate_key_value_memory:
                self.inference_params.allocate_key_value_memory = False

        return output_tensor
            if mpu.is_pipeline_last_stage():
                logits[start:end, ...] = output_tensor

        # Adjust the sequence length back to whatever it was before.
        args.seq_length = orig_seq_length

        return logits


def forward_step(model, tokens, position_ids, attention_mask, inference_params):

    # Hidden size changes when not using recompute, need to tell p2p_communicate
    # functions the correct size
class NoPipeliningForwardStep(ForwardStepBase):

    def __init__(self, model, batch_size, max_sequence_len):
        super().__init__(model)

        self.inference_params = InferenceParams([batch_size], max_sequence_len)


    def __call__(self, tokens, position_ids, attention_mask):

        # Need to tell p2p_communicate functions the correct size.
        args = get_args()
        orig_seq_length = args.seq_length
        args.seq_length = tokens.shape[1]
        assert args.seq_length <= self.inference_params.max_sequence_len
        args.micro_batch_size = tokens.shape[0]
        assert self.inference_params.micro_batch_size_list[0] == tokens.shape[0]
        assert self.inference_params.micro_batch_index == 0

        # Receive from previous stage.
        input_tensor = recv_forward()

        # Forward pass through the model.
    model.set_input_tensor(input_tensor)
    output_tensor = model(tokens, position_ids, attention_mask,
                          inference_params=inference_params)
        self.model.set_input_tensor(input_tensor)
        output_tensor = self.model(tokens, position_ids, attention_mask,
                                   inference_params=self.inference_params)

        # Send output to the next stage.
        send_forward(output_tensor)

        # Reset the sequence lenght to whatwever it was before.
        args.seq_length = orig_seq_length
        # Make sure we do not allocate context memory anymore.
        if self.inference_params.allocate_key_value_memory:
            self.inference_params.allocate_key_value_memory = False

        return output_tensor
+3 −2
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ from .communication import (
    copy_from_last_to_first_pipeline_stage,
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage)
from .forward_step import InferenceForwardStep
from .forward_step import forward_step_provider
from .sampling import sample


@@ -66,7 +66,8 @@ def generate_tokens_probs_and_return_on_first_stage(
    max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

    # forward step.
    forward_step = InferenceForwardStep(model, batch_size, max_sequence_length)
    forward_step = forward_step_provider(model, batch_size, 4,
                                         max_sequence_length)

    # Added termination_id to support the case that we want to terminate the
    # generation once that id is generated.
+6 −5
Original line number Diff line number Diff line
@@ -269,18 +269,19 @@ class ParallelAttention(MegatronModule):
        # ==================================

        if inference_params:
            inf_batch_index = inference_params.micro_batch_size_index
            inf_batch_index = inference_params.micro_batch_index
            assert key_layer.size(1) == \
                inference_params.micro_batch_size_list[inf_batch_index]
            # Adjust the range variables.
            start = self.inference_current_sequence_len_list[inf_batch_index]
            end = start + key_layer.size(0)
            assert end <= inference_params.max_sequence_len
            self.inference_current_sequence_len_list[inf_batch_index] = end
            # Copy key and values.
            self.inference_key_memory_list[inf_batch_index][start:end, ...] =\
                key_layer
            self.inference_value_memory_list[inf_batch_index][start:end, ...] =\
                value_layer
            self.inference_key_memory_list[inf_batch_index][start:end, ...] \
                = key_layer
            self.inference_value_memory_list[inf_batch_index][start:end, ...] \
                = value_layer
            key_layer = \
                self.inference_key_memory_list[inf_batch_index][:end, ...]
            value_layer = \