arguments.py +1 −1

@@ -47,7 +47,7 @@ def add_model_config_args(parser):
                        help='dropout probability for hidden state transformer')
     group.add_argument('--max-position-embeddings', type=int, default=512,
                        help='maximum number of position embeddings to use')
-    group.add_argument('--vocab-size', type=int, default=30522,
+    group.add_argument('--vocab-size', type=int, default=None,
                        help='vocab size to use for non-character-level '
                             'tokenization. This value will only be used when '
                             'creating a tokenizer')
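The new None default pushes the decision about vocabulary size down to whatever builds the tokenizer. A minimal sketch of how that fallback could be resolved downstream, assuming a hypothetical resolve step and a tokenizer object that exposes its own size (neither is part of this diff):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('model', 'model configuration')
group.add_argument('--vocab-size', type=int, default=None,
                   help='vocab size to use for non-character-level tokenization')

def resolve_vocab_size(args, tokenizer):
    # Explicit flag wins; otherwise fall back to the tokenizer's own size.
    if args.vocab_size is not None:
        return args.vocab_size
    return tokenizer.vocab_size  # attribute name is an assumption

args = parser.parse_args([])   # --vocab-size not given: stays None
# vocab_size = resolve_vocab_size(args, tokenizer)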
megatron/model/bert_model.py +2 −0

@@ -83,6 +83,8 @@ class BertLMHead(MegatronModule):
         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         self.bias.model_parallel = True
+        self.bias.partition_dim = 0
+        self.bias.stride = 1
         self.parallel_output = parallel_output
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
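With partition_dim and stride attached, the LM-head bias now carries the same metadata as the weights tagged in mpu/layers.py below. A hypothetical consumer of that metadata (not code from this change) might walk a module and treat annotated parameters specially, for example when merging per-rank shards of a checkpoint:

import torch

def partitioned_params(module):
    # Yield parameters carrying the attributes this change attaches:
    # model_parallel, partition_dim, stride.
    for name, p in module.named_parameters():
        if getattr(p, 'model_parallel', False):
            yield name, p, getattr(p, 'partition_dim', 0), getattr(p, 'stride', 1)

def merge_shards(shards, partition_dim):
    # Simplest case (stride == 1, as set on the bias above): shards from the
    # model-parallel ranks concatenate along the partition dimension.
    return torch.cat(shards, dim=partition_dim)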
megatron/model/transformer.py +1 −0

@@ -372,6 +372,7 @@ class ParallelTransformerLayer(MegatronModule):
     def __init__(self, hyperparameters, attention_mask_func, layer_number):
         super(ParallelTransformerLayer, self).__init__()
+        self.layer_number = layer_number
         self.apply_residual_connection_post_layernorm \
             = hyperparameters['apply_residual_connection_post_layernorm']
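Storing the index on the module makes each ParallelTransformerLayer self-identifying. A toy illustration of why that is handy (hypothetical usage; only the self.layer_number bookkeeping mirrors the diff):

import torch

class ToyLayer(torch.nn.Module):
    def __init__(self, hidden_size, layer_number):
        super().__init__()
        self.layer_number = layer_number          # same bookkeeping as the diff
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        out = self.proj(x)
        # The stored index lets the layer tag its own diagnostics.
        if torch.isnan(out).any():
            raise RuntimeError(f'NaN activation in layer {self.layer_number}')
        return out

layers = torch.nn.ModuleList(ToyLayer(8, i + 1) for i in range(4))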
megatron/mpu/initialize.py +22 −0

@@ -26,6 +26,10 @@
 _MODEL_PARALLEL_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
+# These values enable us to change the mpu sizes on the fly.
+_MPU_WORLD_SIZE = None
+_MPU_RANK = None
+

 def initialize_model_parallel(model_parallel_size_):
     """

@@ -99,13 +103,31 @@ def get_data_parallel_group():
     return _DATA_PARALLEL_GROUP


+def set_model_parallel_world_size(world_size):
+    """Set the model parallel size"""
+    global _MPU_WORLD_SIZE
+    _MPU_WORLD_SIZE = world_size
+
+
 def get_model_parallel_world_size():
     """Return world size for the model parallel group."""
+    global _MPU_WORLD_SIZE
+    if _MPU_WORLD_SIZE is not None:
+        return _MPU_WORLD_SIZE
     return torch.distributed.get_world_size(group=get_model_parallel_group())


+def set_model_parallel_rank(rank):
+    """Set model parallel rank."""
+    global _MPU_RANK
+    _MPU_RANK = rank
+
+
 def get_model_parallel_rank():
     """Return my rank for the model parallel group."""
+    global _MPU_RANK
+    if _MPU_RANK is not None:
+        return _MPU_RANK
     return torch.distributed.get_rank(group=get_model_parallel_group())
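The setters let a caller override the model-parallel size and rank without ever creating a torch.distributed process group, e.g. in a single-process script that inspects or converts sharded checkpoints. A usage sketch, assuming the new functions are imported from megatron.mpu.initialize (the package may also re-export them from megatron.mpu):

from megatron.mpu import initialize as mpu_init

# Pretend to be rank 1 of a 4-way model-parallel job; the getters now return
# the overridden values instead of querying the (uninitialized) process group.
mpu_init.set_model_parallel_world_size(4)
mpu_init.set_model_parallel_rank(1)

assert mpu_init.get_model_parallel_world_size() == 4
assert mpu_init.get_model_parallel_rank() == 1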
megatron/mpu/layers.py +7 −4

@@ -46,6 +46,11 @@ def _initialize_affine_weight(weight, output_size, input_size,
     Build the master weight on all processes and
     scatter the relevant chunk."""
+    weight.model_parallel = True
+    weight.partition_dim = partition_dim
+    weight.stride = stride
+
     # If we only use 1 process for model parallelism, bypass scatter.
     world_size = get_model_parallel_world_size()
     if world_size == 1:

@@ -108,7 +113,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         # Allocate weights.
         self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition,
                                              self.embedding_dim))
-        self.weight.model_parallel = True
         # And initialize.
         _initialize_affine_weight(
             self.weight, self.num_embeddings, self.embedding_dim,

@@ -165,7 +169,6 @@ class ParallelEmbedding(torch.nn.Module):
         # Allocate weights.
         self.weight = Parameter(torch.Tensor(self.num_embeddings,
                                              self.embedding_dim_per_partition))
-        self.weight.model_parallel = True
         # And initialize.
         _initialize_affine_weight(
             self.weight, self.num_embeddings, self.embedding_dim,

@@ -220,10 +223,11 @@ class ColumnParallelLinear(torch.nn.Module):
         # we allocate the transpose.
         self.weight = Parameter(torch.Tensor(self.output_size_per_partition,
                                              self.input_size))
-        self.weight.model_parallel = True
         if bias:
             self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
             self.bias.model_parallel = True
+            self.bias.partition_dim = 0
+            self.bias.stride = stride
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()

@@ -294,7 +298,6 @@ class RowParallelLinear(torch.nn.Module):
         # we allocate the transpose.
         self.weight = Parameter(torch.Tensor(self.output_size,
                                              self.input_size_per_partition))
-        self.weight.model_parallel = True
         if bias:
             self.bias = Parameter(torch.Tensor(self.output_size))
             # Always initialize bias to zero.
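Tagging the weight inside _initialize_affine_weight annotates every parameter that passes through the helper in one place, instead of each layer repeating self.weight.model_parallel = True; the bias parameters, which never go through the helper, keep their explicit tags. A condensed sketch of the resulting pattern (signature simplified; the real helper also builds and scatters a master weight):

import torch

def _tag_and_init_weight(weight, partition_dim, stride, init_method, world_size):
    # Metadata is attached unconditionally, so even the world_size == 1
    # fast path below leaves the parameter fully annotated.
    weight.model_parallel = True
    weight.partition_dim = partition_dim
    weight.stride = stride

    if world_size == 1:
        init_method(weight)   # no partitioning: initialize in place and return
        return
    # ... otherwise build the master weight and scatter this rank's chunk ...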