Commit d899988e authored by Deepak Narayanan

Use batched send and recv instead of torch.distributed.ring_exchange()

parent 2348c99a
+0 −4
@@ -64,10 +64,6 @@ def parse_args(extra_args_provider=None, defaults={},
     args.pipeline_model_parallel_size = min(
         args.pipeline_model_parallel_size,
         (args.world_size // args.tensor_model_parallel_size))
-    if args.pipeline_model_parallel_size > 1:
-        if "ring_exchange" not in dir(torch.distributed):
-            raise Exception('PyTorch with torch.distributed.ring_exchange '
-                            'needed to run pipeline MP!')
     # Checks.
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
+2 −0
@@ -36,6 +36,8 @@ from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
 from .initialize import get_tensor_model_parallel_src_rank
 from .initialize import get_pipeline_model_parallel_first_rank
 from .initialize import get_pipeline_model_parallel_last_rank
+from .initialize import get_pipeline_model_parallel_next_rank
+from .initialize import get_pipeline_model_parallel_prev_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
 from .initialize import initialize_model_parallel
+16 −2
@@ -276,16 +276,30 @@ def get_tensor_model_parallel_src_rank():
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size
 
+def get_pipeline_model_parallel_first_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    return _PIPELINE_GLOBAL_RANKS[0]
+
+def get_pipeline_model_parallel_last_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    last_rank_local = get_pipeline_model_parallel_world_size() - 1
+    return _PIPELINE_GLOBAL_RANKS[last_rank_local]
+
-def get_pipeline_model_parallel_first_rank():
+def get_pipeline_model_parallel_next_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
-    return _PIPELINE_GLOBAL_RANKS[0]
+    rank_in_pipeline = get_pipeline_model_parallel_rank()
+    world_size = get_pipeline_model_parallel_world_size()
+    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
+
 def get_pipeline_model_parallel_prev_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
 
 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
+21 −6
@@ -325,7 +325,7 @@ def setup_model_and_optimizer(model_provider_func):
 
 
 def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
-    """Communicate tensors between stages using torch.distributed.ring_exchange(.) API."""
+    """Communicate tensors between stages."""
     args = get_args()
 
     # Create placeholder tensors for receive in forward and backward directions
@@ -348,11 +348,26 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
                                        dtype=dtype)
 
     # Send tensors in both the forward and backward directions as appropriate.
-    torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
-                                    tensor_recv_prev=tensor_recv_prev,
-                                    tensor_send_next=tensor_send_next,
-                                    tensor_recv_next=tensor_recv_next,
-                                    group=mpu.get_pipeline_model_parallel_group())
+    ops = []
+    if tensor_send_prev is not None:
+        send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
+                                               mpu.get_pipeline_model_parallel_prev_rank())
+        ops.append(send_prev_op)
+    if tensor_recv_prev is not None:
+        recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
+                                               mpu.get_pipeline_model_parallel_prev_rank())
+        ops.append(recv_prev_op)
+    if tensor_send_next is not None:
+        send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
+                                               mpu.get_pipeline_model_parallel_next_rank())
+        ops.append(send_next_op)
+    if tensor_recv_next is not None:
+        recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
+                                               mpu.get_pipeline_model_parallel_next_rank())
+        ops.append(recv_next_op)
+    reqs = torch.distributed.batch_isend_irecv(ops)
+    for req in reqs:
+        req.wait()
 
     return tensor_recv_prev, tensor_recv_next
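For reference, a minimal, self-contained sketch (not part of this commit) of the batched send/recv pattern communicate() now uses: dist.P2POp entries collected into a list and issued together through dist.batch_isend_irecv(), here forming a simple ring over all ranks. Assumes a CUDA build of PyTorch (>= 1.8) with the NCCL backend; the script name is hypothetical.

# Launch with: torchrun --nproc_per_node=2 ring_sketch.py
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # Each rank sends its id to the next rank and receives from the previous,
    # mirroring the pipeline ring in communicate().
    tensor_send_next = torch.full((4,), float(rank), device='cuda')
    tensor_recv_prev = torch.empty(4, device='cuda')

    ops = [
        dist.P2POp(dist.isend, tensor_send_next, (rank + 1) % world_size),
        dist.P2POp(dist.irecv, tensor_recv_prev, (rank - 1) % world_size),
    ]
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()
    # NCCL kernels are asynchronous; synchronize before reading the result.
    torch.cuda.synchronize()
    print(f'rank {rank} received {tensor_recv_prev.tolist()}')


if __name__ == '__main__':
    main()

Batching the four directional operations into one batch_isend_irecv call lets the backend schedule them together and avoids the deadlock-prone ordering of individual blocking send/recv calls, which is what made the custom ring_exchange() primitive unnecessary.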