Commit 6e1bde1e authored by Jared Casper

Merge branch 'slym/allreduce_custom' into 'main'

allreduce overlap with wgrad gemm

See merge request ADLR/megatron-lm!316
parents b7ae685f 4df8b7a2
megatron/arguments.py +5 −0
@@ -462,6 +462,11 @@ def _add_training_args(parser):
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
    group.add_argument('--no-async-tensor-model-parallel-allreduce',
                       action='store_true',
                       help='Disable asynchronous execution of '
                       'tensor-model-parallel all-reduce with weight '
                       'gradient computation of a column-linear layer.')
    return parser
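
For clarity on the flag's semantics: with action='store_true' the attribute
defaults to False, so the asynchronous all-reduce is enabled by default and the
flag only turns it off. A minimal standalone sketch (a hypothetical parser
mirroring just this flag, not Megatron's full one):

    import argparse

    # Hypothetical parser containing only the new flag.
    parser = argparse.ArgumentParser()
    parser.add_argument('--no-async-tensor-model-parallel-allreduce',
                        action='store_true')

    # With no arguments the flag is False: the async all-reduce stays enabled.
    args = parser.parse_args([])
    assert args.no_async_tensor_model_parallel_allreduce is False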


megatron/initialize.py +12 −1
@@ -176,11 +176,22 @@ def _initialize_distributed():
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Increase the CUDA stream priority of NCCL ops when overlapping them with other ops
        if (not args.no_async_tensor_model_parallel_allreduce and
                args.tensor_model_parallel_size > 1):
            from torch._C._distributed_c10d import ProcessGroupNCCL

            pg_options = ProcessGroupNCCL.Options()
            pg_options.is_high_priority_stream = True
            pg_options._timeout = timedelta(days=7)
        else:
            pg_options = None
        # Call torch.distributed.init_process_group
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            timeout=timedelta(days=7),
            pg_options=pg_options)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
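
For context, is_high_priority_stream asks NCCL to launch its communication
kernels on a high-priority CUDA stream, so the all-reduce can be scheduled
ahead of concurrent compute kernels. The underlying stream-priority mechanism,
sketched independently of NCCL (an illustration, not part of this diff; lower
numbers mean higher priority):

    import torch

    if torch.cuda.is_available():
        comm_stream = torch.cuda.Stream(priority=-1)    # high priority, e.g. for communication
        compute_stream = torch.cuda.Stream(priority=0)  # default priority, e.g. for compute
        with torch.cuda.stream(comm_stream):
            # Kernels launched here are scheduled ahead of default-priority work.
            pass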
megatron/mpu/layers.py +51 −5
@@ -27,6 +27,7 @@ from torch.nn.parameter import Parameter

from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
@@ -198,6 +199,37 @@ class VocabParallelEmbedding(torch.nn.Module):
        return output


class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
    """
    Column-parallel linear layer whose backward pass overlaps the
    all-reduce of the input gradient with the weight-gradient computation.
    """
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        output = torch.matmul(input, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias
        grad_input = grad_output.matmul(weight)
        # Asynchronous all-reduce
        handle = torch.distributed.all_reduce(
                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
        # Briefly delay the start of the weight-gradient computation (~3 us) so
        # the all-reduce is scheduled first and is allocated GPU resources
        _ = torch.empty(1, device=grad_output.device) + 1
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None
        handle.wait()
        return grad_input, grad_weight, grad_bias
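
The overlap trick in isolation, as a minimal sketch: launch the input-gradient
all-reduce asynchronously, compute the weight gradient while it is in flight,
then wait. This assumes torch.distributed is already initialized and, for
simplicity, uses the default process group rather than Megatron's
tensor-model-parallel group:

    import torch
    import torch.distributed as dist

    def overlapped_linear_backward(grad_output, input, weight):
        # dgrad GEMM: produces the tensor that must be all-reduced across ranks.
        grad_input = grad_output.matmul(weight)
        # Kick off the all-reduce without blocking subsequent kernel launches.
        handle = dist.all_reduce(grad_input, async_op=True)
        # wgrad GEMM overlaps with the in-flight all-reduce.
        grad_weight = grad_output.t().matmul(input)
        # Synchronize before grad_input is consumed downstream.
        handle.wait()
        return grad_input, grad_weight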


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

@@ -272,16 +304,30 @@ class ColumnParallelLinear(torch.nn.Module):
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        self.async_tensor_model_parallel_allreduce = (
                not args.no_async_tensor_model_parallel_allreduce and
                world_size > 1)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        if self.async_tensor_model_parallel_allreduce:
            input_shape = input_.shape
            input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
            # Matrix multiply with asynchronous all-reduce execution
            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
                    input_, self.weight, bias)
            output_parallel = output_parallel.view(
                    input_shape[0], input_shape[1], output_parallel.shape[1])
        else:
            # Set up backprop all-reduce.
            input_parallel = copy_to_tensor_model_parallel_region(input_)
            # Matrix multiply.
            output_parallel = F.linear(input_parallel, self.weight, bias)

        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_tensor_model_parallel_region(output_parallel)
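
A note on the reshape round-trip in the async path above: the custom autograd
function expects a 2-D input, so the [sequence, batch, hidden] activation is
flattened before the matmul and restored afterwards. Sketched with
hypothetical sizes:

    import torch

    x = torch.randn(512, 8, 1024)                 # [sequence, batch, hidden]
    x2d = x.view(512 * 8, 1024)                   # flatten leading dims for the 2-D matmul
    out2d = x2d.matmul(torch.randn(1024, 4096))   # stand-in for the column-parallel GEMM
    out = out2d.view(512, 8, -1)                  # restore [sequence, batch, hidden_out]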