Commit 417c7f6a authored by Boris Fomitchev

Changes for NeMo/lightning compatibility

parent 07ebf714
+0 −6
@@ -39,12 +39,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # This is a temporary WAR to handle simple cases like pytest calling initialize twice with the same args.
    # Need to implement a clean factory init.
    if mpu.model_parallel_is_initialized():
        return
    
    
    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
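For context, a minimal single-file sketch of the early-return guard that the removed lines provided; module layout and names here are illustrative, not the actual megatron package:

import torch

_MODEL_PARALLEL_GROUP = None

def model_parallel_is_initialized():
    # Stand-in for mpu.model_parallel_is_initialized() in this sketch.
    return _MODEL_PARALLEL_GROUP is not None

def initialize_megatron(extra_args_provider=None, args_defaults=None):
    assert torch.cuda.is_available(), 'Megatron requires CUDA.'
    # WAR: a second call with the same args (e.g. from pytest) becomes a
    # no-op instead of re-creating process groups.
    if model_parallel_is_initialized():
        return
    # ... parse args, build tokenizer, set up global state ...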
+8 −1
@@ -88,13 +88,16 @@ def model_parallel_is_initialized():
        return False
    return True


def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to."""
    assert _MODEL_PARALLEL_GROUP is not None, \
        'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP

def set_model_parallel_group(group):
    """Set model parallel group."""
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = group

def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
@@ -102,6 +105,10 @@ def get_data_parallel_group():
        'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP

def set_data_parallel_group(group):
    """Set data parallel group."""
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = group

def set_model_parallel_world_size(world_size):
    """Set the model parallel size"""
+12 −8
@@ -127,18 +127,22 @@ class VocabParallelEmbedding(torch.nn.Module):
            self.num_embeddings_per_partition, 0, init_method)

    def forward(self, input_):
        if self.num_embeddings_per_partition < self.num_embeddings:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        if self.num_embeddings_per_partition < self.num_embeddings:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_model_parallel_region(output_parallel)
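The reworked forward only builds and applies the vocabulary-range mask when the embedding table is actually partitioned. A single-process toy illustration of that masking path (made-up sizes, no all-reduce):

import torch
import torch.nn.functional as F

vocab_start_index, vocab_end_index = 4, 8      # this rank owns ids 4..7
weight = torch.randn(4, 3)                     # 4 local rows, hidden size 3
input_ = torch.tensor([[1, 5, 7, 9]])          # mix of owned and foreign ids

input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
masked_input = input_.clone() - vocab_start_index   # shift to local row ids
masked_input[input_mask] = 0                         # foreign ids -> row 0
output_parallel = F.embedding(masked_input, weight)
output_parallel[input_mask, :] = 0.0                 # zero rows this rank does not own
# reduce_from_model_parallel_region() would now sum contributions from all
# partitions so every position keeps exactly one non-zero embedding.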
+25 −12
@@ -15,20 +15,19 @@

import torch

from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
from .utils import split_tensor_along_last_dim


def _reduce(input_):
    """All-reduce the the input tensor across model parallel group."""
    group = get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if torch.distributed.get_world_size(group=group) == 1:
    if get_model_parallel_world_size()==1:
        return input_

    # All-reduce.
    torch.distributed.all_reduce(input_, group=group)
    torch.distributed.all_reduce(input_, group=get_model_parallel_group())

    return input_

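A standalone, single-process sketch of the call pattern in the patched _reduce(), using a one-rank gloo group so the world_size == 1 bypass applies (address and port are placeholders):

import os
import torch
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

t = torch.ones(3)
if dist.get_world_size() == 1:
    pass                    # bypass, as in the patched _reduce()
else:
    dist.all_reduce(t)      # in-place sum across the group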
@@ -36,18 +35,17 @@ def _reduce(input_):
def _split(input_):
    """Split the tensor along its last dimension and keep the
    corresponding slice."""
    group = get_model_parallel_group()

    world_size = get_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if torch.distributed.get_world_size(group=group) == 1:
    if world_size==1:
        return input_

    # Split along last dimension.
    world_size = torch.distributed.get_world_size(group=group)
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = torch.distributed.get_rank(group=group)
    rank = get_model_parallel_rank()
    output = input_list[rank].contiguous()

    return output
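A single-process illustration of the split-and-keep-own-slice logic in _split(); world_size and rank are made up, and torch.split stands in for split_tensor_along_last_dim:

import torch

world_size, rank = 4, 2
input_ = torch.arange(24.0).reshape(2, 12)
# Chunk along the last dimension, as split_tensor_along_last_dim is assumed to do.
input_list = torch.split(input_, input_.size(-1) // world_size, dim=-1)
# torch.split returns views; slices along the last dimension are generally not
# contiguous, hence the .contiguous() in the patched code.
output = input_list[rank].contiguous()
print(output)   # columns 6..8 of every row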
@@ -55,16 +53,15 @@ def _split(input_):

def _gather(input_):
    """Gather tensors and concatinate along the last dimension."""
    group = get_model_parallel_group()

    world_size = get_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if torch.distributed.get_world_size(group=group) == 1:
    if world_size==1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = torch.distributed.get_rank(group=group)
    world_size = torch.distributed.get_world_size(group=group)
    rank = get_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
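A single-process sketch of the concatenation _gather() performs; in the real code torch.distributed.all_gather fills tensor_list before torch.cat, so the other entries are faked here:

import torch

world_size, rank = 3, 1
input_ = torch.full((2, 4), float(rank))
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
for r in range(world_size):          # fake what all_gather would provide
    tensor_list[r] = torch.full((2, 4), float(r))
output = torch.cat(tensor_list, dim=input_.dim() - 1)   # shape (2, 12)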
@@ -79,6 +76,10 @@ def _gather(input_):
class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return input_
    
    @staticmethod
    def forward(ctx, input_):
        return input_
@@ -91,6 +92,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-redcue the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)
    
    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)
@@ -103,6 +108,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the corresponding chuck to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split(input_)
@@ -115,6 +124,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from model parallel region and concatinate."""

    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)
    
    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)
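The added symbolic staticmethods use the hook that torch.onnx looks for when exporting a custom autograd.Function; the diff simply reuses the eager implementations there. For reference, a minimal hypothetical Function showing the general shape of the hook (not part of the diff):

import torch

class _Identity(torch.autograd.Function):
    @staticmethod
    def symbolic(g, input_):
        # Invoked during torch.onnx.export() tracing in place of forward();
        # here it emits a plain ONNX Identity node.
        return g.op('Identity', input_)

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output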