Commit 5489bda9 authored by Deepak Narayanan

More comments and some cleanup (e.g., better variable names)

parent 626645c0
+13 −1
@@ -554,12 +554,24 @@ class ParallelTransformer(MegatronModule):
                self_attn_mask_type=self_attn_mask_type)
        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                    args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])
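
To sanity-check the layer-to-stage assignment described in the new comments, the following minimal sketch (not part of the commit; layer_assignment is a hypothetical helper) reproduces both examples using the same offset arithmetic as ParallelTransformer:

def layer_assignment(num_layers, pipeline_size, virtual_size):
    # Layers per model chunk: layers per stage divided by number of chunks.
    layers_per_chunk = (num_layers // pipeline_size) // virtual_size
    assignment = {}
    for rank in range(pipeline_size):
        chunks = []
        for virtual_rank in range(virtual_size):
            # Same formula as the offset computation above.
            offset = virtual_rank * (num_layers // virtual_size) + \
                rank * layers_per_chunk
            chunks.append(list(range(offset, offset + layers_per_chunk)))
        assignment[rank] = chunks
    return assignment

print(layer_assignment(8, 2, 4))  # {0: [[0], [2], [4], [6]], 1: [[1], [3], [5], [7]]}
print(layer_assignment(8, 2, 2))  # {0: [[0, 1], [4, 5]], 1: [[2, 3], [6, 7]]}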
+13 −9
@@ -271,10 +271,8 @@ def get_pipeline_model_parallel_rank():
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        if get_virtual_pipeline_model_parallel_world_size() is not None and \
            get_virtual_pipeline_model_parallel_rank() != 0:
            return False
    return get_pipeline_model_parallel_rank() == 0

@@ -282,11 +280,11 @@ def is_pipeline_first_stage(ignore_virtual=False):
def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = \
            get_virtual_pipeline_model_parallel_world_size()
        if virtual_pipeline_model_parallel_world_size is not None and \
            get_virtual_pipeline_model_parallel_rank() != (
                virtual_pipeline_model_parallel_world_size - 1):
            return False
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1)
@@ -304,6 +302,12 @@ def set_virtual_pipeline_model_parallel_rank(rank):
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE


def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
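
As a quick illustration of the virtual-stage checks above, here is a hedged, self-contained sketch (the function names below are illustrative stand-ins, not Megatron APIs): with interleaving, only virtual chunk 0 on the first pipeline rank is the true first stage, and only the last virtual chunk on the last pipeline rank is the true last stage.

def is_first_stage(pipeline_rank, virtual_rank=None):
    # Only the first model chunk on pipeline rank 0 holds the first stage.
    if virtual_rank is not None and virtual_rank != 0:
        return False
    return pipeline_rank == 0

def is_last_stage(pipeline_rank, pipeline_size, virtual_rank=None,
                  virtual_size=None):
    # Only the last model chunk on the last pipeline rank holds the last stage.
    if virtual_size is not None and virtual_rank != virtual_size - 1:
        return False
    return pipeline_rank == pipeline_size - 1

assert is_first_stage(pipeline_rank=0, virtual_rank=0)
assert not is_first_stage(pipeline_rank=0, virtual_rank=1)
assert is_last_stage(1, 2, virtual_rank=1, virtual_size=2)
assert not is_last_stage(1, 2, virtual_rank=0, virtual_size=2)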
+42 −10
@@ -23,7 +23,24 @@ from megatron import mpu

def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 use_ring_exchange=False):
    """Communicate tensors between stages."""
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.

    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
@@ -50,6 +67,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
@@ -67,27 +85,32 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
@@ -101,6 +124,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,


def recv_forward(timers=None, use_ring_exchange=False):
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
@@ -118,6 +142,7 @@ def recv_forward(timers=None, use_ring_exchange=False):


def recv_backward(timers=None, use_ring_exchange=False):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
@@ -135,6 +160,7 @@ def recv_backward(timers=None, use_ring_exchange=False):


def send_forward(output_tensor, timers=None, use_ring_exchange=False):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send').start()
@@ -149,6 +175,7 @@ def send_forward(output_tensor, timers=None, use_ring_exchange=False):


def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers('backward-send').start()
@@ -163,6 +190,7 @@ def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):


def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
@@ -180,6 +208,7 @@ def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=Fal


def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
@@ -197,6 +226,7 @@ def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange


def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
@@ -211,6 +241,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):


def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
@@ -227,6 +258,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
    input_tensor, output_tensor_grad = _communicate(
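
For readers unfamiliar with the batched point-to-point API that _communicate builds on, here is a hedged usage sketch (not from the commit) of torch.distributed.P2POp and torch.distributed.batch_isend_irecv, forming the same send-next / recv-prev pairing the pipeline uses; run it with two processes, e.g. torchrun --nproc_per_node=2 sketch.py:

import torch
import torch.distributed as dist

dist.init_process_group(backend='gloo')  # Megatron uses NCCL on GPUs.
rank, world_size = dist.get_rank(), dist.get_world_size()
send_buf = torch.full((4,), float(rank))
recv_buf = torch.empty(4)

next_rank = (rank + 1) % world_size
prev_rank = (rank - 1) % world_size
ops = [dist.P2POp(dist.isend, send_buf, next_rank),
       dist.P2POp(dist.irecv, recv_buf, prev_rank)]
# batch_isend_irecv returns one request per op; wait on all of them.
for req in dist.batch_isend_irecv(ops):
    req.wait()
print(f'rank {rank} received {recv_buf.tolist()} from rank {prev_rank}')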
+12 −11
@@ -136,19 +136,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_first_stage():
@@ -164,11 +164,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_last_stage():
@@ -317,8 +317,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
    return losses_reduced


def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

+5 −5
@@ -50,7 +50,7 @@ from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import forward_backward_no_pipelining
from megatron.schedules import forward_backward_pipelining_without_interleaving
from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory

@@ -340,7 +340,7 @@ def train_step(forward_step_func, data_iterator,
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    losses_reduced = forward_backward_func(
@@ -681,7 +681,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                if args.virtual_pipeline_model_parallel_size is not None:
                    forward_backward_func = forward_backward_pipelining_with_interleaving
                else:
                    forward_backward_func = forward_backward_pipelining_without_interleaving
            else:
                forward_backward_func = forward_backward_no_pipelining
            loss_dicts = forward_backward_func(
@@ -692,8 +692,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                # Reduce across processes.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \