Commit 27fc4689 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Break up tensors sent between pipeline stages into smaller chunks that can be all-gathered

parent 8e922d5b
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -566,6 +566,8 @@ def _add_distributed_args(parser):
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
    group.add_argument('--scatter-gather-tensors-in-pipeline', action='store_true',
                       help='Use scatter/gather to optimize communication of tensors in pipeline')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
+2 −0
Original line number Diff line number Diff line
@@ -59,6 +59,8 @@ from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks

from .utils import divide
from .utils import split_tensor_along_last_dim
+26 −2
Original line number Diff line number Diff line
@@ -29,20 +29,33 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    tensor_recv_prev = None
    tensor_recv_next = None
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = (
            args.seq_length * args.micro_batch_size * args.hidden_size) // \
                    mpu.get_tensor_model_parallel_world_size()
    else:
        tensor_chunk_shape = tensor_shape
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_shape,
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_shape,
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    if args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if use_ring_exchange:
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
@@ -71,6 +84,17 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
    torch.cuda.synchronize()

    tensor_recv_prev_before = tensor_recv_prev
    if args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()

    return tensor_recv_prev, tensor_recv_next

+2 −2
Original line number Diff line number Diff line
@@ -381,11 +381,11 @@ def train_step(forward_step_func, data_iterator,

    # Update parameters.
    timers('optimizer').start()
    update_successfull, grad_norm = optimizer.step()
    update_successful, grad_norm = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    if update_successfull:
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size