Commit ca8dd4ac authored by Jared Casper

Merge branch 'mp_merger' into 'master'

Model parallel merger

See merge request ADLR/megatron-lm!28
parents 5df85022 57c2060f
+1 −1
@@ -47,7 +47,7 @@ def add_model_config_args(parser):
                       help='dropout probability for hidden state transformer')
    group.add_argument('--max-position-embeddings', type=int, default=512,
                       help='maximum number of position embeddings to use')
-    group.add_argument('--vocab-size', type=int, default=30522,
+    group.add_argument('--vocab-size', type=int, default=None,
                       help='vocab size to use for non-character-level '
                       'tokenization. This value will only be used when '
                       'creating a tokenizer')
+2 −0
@@ -83,6 +83,8 @@ class BertLMHead(MegatronModule):

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        self.bias.model_parallel = True
+        self.bias.partition_dim = 0
+        self.bias.stride = 1
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
+1 −0
@@ -372,6 +372,7 @@ class ParallelTransformerLayer(MegatronModule):
    def __init__(self, hyperparameters, attention_mask_func, layer_number):

        super(ParallelTransformerLayer, self).__init__()
+        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm \
            = hyperparameters['apply_residual_connection_post_layernorm']
+22 −0
@@ -26,6 +26,10 @@ _MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

+# These values enable us to change the mpu sizes on the fly.
+_MPU_WORLD_SIZE = None
+_MPU_RANK = None
+

def initialize_model_parallel(model_parallel_size_):
    """
@@ -99,13 +103,31 @@ def get_data_parallel_group():
    return _DATA_PARALLEL_GROUP


+def set_model_parallel_world_size(world_size):
+    """Set the model parallel size"""
+    global _MPU_WORLD_SIZE
+    _MPU_WORLD_SIZE = world_size
+
+
def get_model_parallel_world_size():
    """Return world size for the model parallel group."""
+    global _MPU_WORLD_SIZE
+    if _MPU_WORLD_SIZE is not None:
+        return _MPU_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_model_parallel_group())


+def set_model_parallel_rank(rank):
+    """Set model parallel rank."""
+    global _MPU_RANK
+    _MPU_RANK = rank
+
+
def get_model_parallel_rank():
    """Return my rank for the model parallel group."""
+    global _MPU_RANK
+    if _MPU_RANK is not None:
+        return _MPU_RANK
    return torch.distributed.get_rank(group=get_model_parallel_group())


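Note: the setters added above are what let a conversion script "change the mpu sizes on the fly": they override the model-parallel world size and rank without a torch.distributed process group being initialized. A minimal sketch of how a merging script might use them; the import path, function name, and checkpoint handling here are assumptions for illustration, not part of this merge request:

    import torch
    import mpu  # assumption: mpu is importable as a package and re-exports these helpers

    def load_partitions(checkpoint_paths):
        """Load each model-parallel partition while pretending to be its rank."""
        # Tell mpu how many partitions the checkpoints were trained with.
        mpu.set_model_parallel_world_size(len(checkpoint_paths))
        partitions = []
        for rank, path in enumerate(checkpoint_paths):
            # Any code that calls mpu.get_model_parallel_rank() while this
            # checkpoint is being processed now sees `rank`, even though no
            # distributed group exists in this single process.
            mpu.set_model_parallel_rank(rank)
            partitions.append(torch.load(path, map_location='cpu'))
        return partitions
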
+7 −4
@@ -46,6 +46,11 @@ def _initialize_affine_weight(weight, output_size, input_size,

    Build the master weight on all processes and scatter
    the relevant chunk."""

+    weight.model_parallel = True
+    weight.partition_dim = partition_dim
+    weight.stride = stride
+
    # If we only use 1 process for model parallelism, bypass scatter.
    world_size = get_model_parallel_world_size()
    if world_size == 1:
@@ -108,7 +113,6 @@ class VocabParallelEmbedding(torch.nn.Module):
        # Allocate weights.
        self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition,
                                             self.embedding_dim))
-        self.weight.model_parallel = True
        # And initialize.
        _initialize_affine_weight(
            self.weight, self.num_embeddings, self.embedding_dim,
@@ -165,7 +169,6 @@ class ParallelEmbedding(torch.nn.Module):
        # Allocate weights.
        self.weight = Parameter(torch.Tensor(self.num_embeddings,
                                             self.embedding_dim_per_partition))
-        self.weight.model_parallel = True
        # And initialize.
        _initialize_affine_weight(
            self.weight, self.num_embeddings, self.embedding_dim,
@@ -220,10 +223,11 @@ class ColumnParallelLinear(torch.nn.Module):
        # we allocate the transpose.
        self.weight = Parameter(torch.Tensor(self.output_size_per_partition,
                                             self.input_size))
-        self.weight.model_parallel = True
        if bias:
            self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
            self.bias.model_parallel = True
+            self.bias.partition_dim = 0
+            self.bias.stride = stride
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
@@ -294,7 +298,6 @@ class RowParallelLinear(torch.nn.Module):
        # we allocate the transpose.
        self.weight = Parameter(torch.Tensor(self.output_size,
                                             self.input_size_per_partition))
-        self.weight.model_parallel = True
        if bias:
            self.bias = Parameter(torch.Tensor(self.output_size))
            # Always initialize bias to zero.
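Note: tagging each partitioned parameter with model_parallel, partition_dim, and stride (as done above and in _initialize_affine_weight) gives a merging tool enough metadata to reassemble full tensors from per-rank shards. A rough sketch of that merge step under those assumptions; the helper name and calling convention are hypothetical, not taken from this merge request:

    import torch

    def merge_parameter(shards, partition_dim, stride=1):
        """Rebuild one full parameter from its per-rank shards."""
        # Simple case: each rank holds one contiguous slice of the full tensor.
        if stride == 1:
            return torch.cat(shards, dim=partition_dim)
        # Strided case (e.g. fused weights): each shard holds `stride` chunks.
        # Regroup chunks by index across ranks, merge each group, then re-fuse.
        chunked = [torch.chunk(shard, stride, dim=partition_dim) for shard in shards]
        merged = [torch.cat([chunks[i] for chunks in chunked], dim=partition_dim)
                  for i in range(stride)]
        return torch.cat(merged, dim=partition_dim)

For example, merge_parameter([w0, w1], partition_dim=0) would rebuild a column-parallel weight from two ranks' shards, while parameters without the model_parallel attribute can simply be copied from rank 0.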