Commit b31e1296 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'fix-inference' into 'main'

Fix inference after T5 pipeline merge

See merge request ADLR/megatron-lm!332
parents cdc614cf f2c35bb0
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -357,6 +357,12 @@ class TransformerLanguageModel(MegatronModule):

    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()"""

        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with both encoder and decoder'
+16 −9
Original line number Diff line number Diff line
@@ -53,6 +53,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None

    # Some legacy inference code doesn't set the tensor shape, do so now
    # for the normal values for gpt/bert. This could be removed if inference
    # code is changed to provide tensor_shape.
    if tensor_shape is None:
        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

    override_scatter_gather_tensors_in_pipeline = False
    if args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
@@ -143,7 +150,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    return tensor_recv_prev, tensor_recv_next


def recv_forward(tensor_shape, dtype_=None, timers=None):
def recv_forward(tensor_shape=None, dtype_=None, timers=None):
    """Receive tensor from previous rank in pipeline (forward receive)."""

    if mpu.is_pipeline_first_stage():
@@ -163,7 +170,7 @@ def recv_forward(tensor_shape, dtype_=None, timers=None):
    return input_tensor


def recv_backward(tensor_shape, timers=None):
def recv_backward(tensor_shape=None, timers=None):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
@@ -181,7 +188,7 @@ def recv_backward(tensor_shape, timers=None):
    return output_tensor_grad


def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
    """Send tensor to next rank in pipeline (forward send)."""

    if not mpu.is_pipeline_last_stage():
@@ -198,7 +205,7 @@ def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
            timers('forward-send').stop()


def send_backward(input_tensor_grad, tensor_shape, timers=None):
def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
@@ -213,7 +220,7 @@ def send_backward(input_tensor_grad, tensor_shape, timers=None):
            timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
@@ -231,7 +238,7 @@ def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
@@ -249,7 +256,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=None):
def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv').start()
@@ -264,7 +271,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=Non
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, timers=None):
def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv').start()
@@ -281,7 +288,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, time

def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, tensor_shape, timers=None):
        recv_next, tensor_shape=None, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
+1 −1
Original line number Diff line number Diff line
@@ -193,7 +193,7 @@ def update_train_iters(args):
    print_rank_0('setting training iterations to {}'.format(args.train_iters))


def get_model(model_provider_func, model_type, wrap_with_ddp=True):
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model."""
    args = get_args()
    args.model_type = model_type