Commit ff2f0a05 authored by mshoeybi's avatar mshoeybi
Browse files

further refactoring

parent 390ddef8
Loading
Loading
Loading
Loading
+66 −8
Original line number Diff line number Diff line
@@ -15,13 +15,18 @@

"""Forward step utilities."""


import torch
from collections.abc import Iterable
from enum import Enum

from megatron.p2p_communication import recv_forward, send_forward
from megatron import get_args


class ForwardStepTypes(Enum):
    NO_PIPELINING = 1



class InferenceParams:

    def __init__(self, micro_batch_size_list, max_sequence_len):
@@ -31,10 +36,67 @@ class InferenceParams:

        self.micro_batch_size_list = micro_batch_size_list
        self.max_sequence_len = max_sequence_len
        self.allocate_key_value_memory = False
        self.allocate_key_value_memory = True
        self.micro_batch_size_index = 0


class InferenceForwardStep:

    def __init__(self, model, batch_size, max_sequence_len):

        if isinstance(model, Iterable):
            for this_model in model:
                this_model.eval()
        else:
            model.eval()
        self.model = model

        self.inference_params = InferenceParams([batch_size], max_sequence_len)
        self.forward_step_type = ForwardStepTypes.NO_PIPELINING


    def __call__(self, tokens, position_ids, attention_mask):

        if self.forward_step_type == ForwardStepTypes.NO_PIPELINING:
            return self._forward_step_no_pipelining(tokens, position_ids,
                                                    attention_mask)

        raise Exception('unknown forward step type {}'.format(
            self.forward_step_type))


    def _forward_step_no_pipelining(self, tokens, position_ids, attention_mask):

        # Need to tell p2p_communicate functions the correct size.
        args = get_args()
        orig_seq_length = args.seq_length
        args.seq_length = tokens.shape[1]
        assert args.seq_length <= self.inference_params.max_sequence_len
        args.micro_batch_size = tokens.shape[0]
        assert self.inference_params.micro_batch_size_list[0] == tokens.shape[0]
        assert self.inference_params.micro_batch_size_index == 0

        # Receive from previous stage.
        input_tensor = recv_forward()

        # Forward pass through the model.
        self.model.set_input_tensor(input_tensor)
        output_tensor = self.model(tokens, position_ids, attention_mask,
                                   inference_params=self.inference_params)

        # Send output to the next stage.
        send_forward(output_tensor)

        # Reset the sequence lenght to whatwever it was before.
        args.seq_length = orig_seq_length
        # Make sure we do not allocate context memory anymore.
        if self.inference_params.allocate_key_value_memory:
            self.inference_params.allocate_key_value_memory = False

        return output_tensor



def forward_step(model, tokens, position_ids, attention_mask, inference_params):

    # Hidden size changes when not using recompute, need to tell p2p_communicate
@@ -56,7 +118,3 @@ def forward_step(model, tokens, position_ids, attention_mask, inference_params):
    args.seq_length = orig_seq_length

    return output_tensor



+5 −14
Original line number Diff line number Diff line
@@ -15,7 +15,6 @@

"""Generation utilities."""


import torch
import torch.nn.functional as F

@@ -25,7 +24,7 @@ from .communication import (
    copy_from_last_to_first_pipeline_stage,
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage)
from .forward_step import forward_step, InferenceParams
from .forward_step import InferenceForwardStep
from .sampling import sample


@@ -66,6 +65,9 @@ def generate_tokens_probs_and_return_on_first_stage(
    max_sequence_length = tokens.size(1)
    max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

    # forward step.
    forward_step = InferenceForwardStep(model, batch_size, max_sequence_length)

    # Added termination_id to support the case that we want to terminate the
    # generation once that id is generated.
    if hasattr(args, 'eos_id'):
@@ -109,20 +111,10 @@ def generate_tokens_probs_and_return_on_first_stage(
    attention_mask, position_ids = _build_attention_mask_and_position_ids(
        tokens)

    # Set inference params
    inference_params = InferenceParams([batch_size], max_sequence_length)
    
    model.eval()
    with torch.no_grad():
        prev_context_length = 0
        for context_length in range(min_prompt_length, max_sequence_length):

            # If we are starting from scratch, allocate memory for the entire
            # context, otherwise  set this to false so the memory is not
            # reallocated.
            inference_params.allocate_key_value_memory = \
                (prev_context_length == 0)

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
@@ -130,8 +122,7 @@ def generate_tokens_probs_and_return_on_first_stage(
                ..., prev_context_length:context_length, :context_length]

            # logits will be meanigful only in the last pipeline stage.
            logits = forward_step(model, tokens2use, positions2use,
                                  attention_mask2use, inference_params)
            logits = forward_step(tokens2use, positions2use, attention_mask2use)

            if mpu.is_pipeline_last_stage():
                # Always the last stage should have an output.