megatron/initialize.py +0 −6

```diff
@@ -39,12 +39,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     # Make sure cuda is available.
     assert torch.cuda.is_available(), 'Megatron requires CUDA.'
 
-    # This is temporary WAR to make simple case like pytest calling with
-    # same args twice
-    # Need to implement clean factory init.
-    if mpu.model_parallel_is_initialized():
-        return
-
     # Parse args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
     set_global_variables(extra_args_provider=extra_args_provider,
```
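With the workaround removed, `initialize_megatron()` is no longer idempotent: calling it twice (the pytest scenario the deleted comment mentions) re-runs argument parsing and global setup. A minimal sketch, assuming a caller wants the old guard back at the call site (`model_parallel_is_initialized` is the same check the deleted code used):

```python
# A minimal sketch: re-create the removed idempotency guard at the call
# site instead of inside initialize_megatron() itself.
from megatron import mpu
from megatron.initialize import initialize_megatron

if not mpu.model_parallel_is_initialized():
    initialize_megatron(extra_args_provider=None, args_defaults={})
```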
megatron/mpu/initialize.py +8 −1

```diff
@@ -88,13 +88,16 @@ def model_parallel_is_initialized():
         return False
     return True
 
 
 def get_model_parallel_group():
     """Get the model parallel group the caller rank belongs to."""
     assert _MODEL_PARALLEL_GROUP is not None, \
         'model parallel group is not initialized'
     return _MODEL_PARALLEL_GROUP
 
-
+def set_model_parallel_group(group):
+    """Set model parallel group."""
+    global _MODEL_PARALLEL_GROUP
+    _MODEL_PARALLEL_GROUP = group
 
 def get_data_parallel_group():
     """Get the data parallel group the caller rank belongs to."""
@@ -102,6 +105,10 @@ def get_data_parallel_group():
         'data parallel group is not initialized'
     return _DATA_PARALLEL_GROUP
 
+def set_data_parallel_group(group):
+    """Set data parallel group."""
+    global _DATA_PARALLEL_GROUP
+    _DATA_PARALLEL_GROUP = group
 
 def set_model_parallel_world_size(world_size):
     """Set the model parallel size"""
```
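The new setters let an outside framework that has already built its own process groups inject them into `mpu` directly, instead of calling `initialize_model_parallel()`. A hypothetical sketch (the 4-rank layout is made up, and it assumes `torch.distributed.init_process_group()` has already run on every rank; real code would create one subgroup per replica on all ranks and keep the one containing the current rank):

```python
import torch
from megatron.mpu.initialize import (model_parallel_is_initialized,
                                     set_data_parallel_group,
                                     set_model_parallel_group,
                                     set_model_parallel_world_size)

# Hypothetical 2-way model parallel x 2-way data parallel layout; for
# brevity this shows only the two groups that contain rank 0.
model_group = torch.distributed.new_group(ranks=[0, 1])
data_group = torch.distributed.new_group(ranks=[0, 2])

set_model_parallel_group(model_group)  # inject, no initialize_model_parallel()
set_data_parallel_group(data_group)
set_model_parallel_world_size(2)       # pre-existing setter, shown above
assert model_parallel_is_initialized()
```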
megatron/mpu/layers.py +12 −8

```diff
@@ -127,18 +127,22 @@ class VocabParallelEmbedding(torch.nn.Module):
                                   self.num_embeddings_per_partition, 0, init_method)
 
     def forward(self, input_):
-        # Build the mask.
-        input_mask = (input_ < self.vocab_start_index) | \
-                     (input_ >= self.vocab_end_index)
-        # Mask the input.
-        masked_input = input_.clone() - self.vocab_start_index
-        masked_input[input_mask] = 0
+        if self.num_embeddings_per_partition < self.num_embeddings:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | \
+                         (input_ >= self.vocab_end_index)
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
         # Get the embeddings.
         output_parallel = F.embedding(masked_input, self.weight,
                                       self.padding_idx, self.max_norm,
                                       self.norm_type,
                                       self.scale_grad_by_freq,
                                       self.sparse)
         # Mask the output embedding.
-        output_parallel[input_mask, :] = 0.0
+        if self.num_embeddings_per_partition < self.num_embeddings:
+            output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
         output = reduce_from_model_parallel_region(output_parallel)
```
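The guard `self.num_embeddings_per_partition < self.num_embeddings` skips the masking entirely when this rank holds the whole vocabulary (model parallel size 1), avoiding the clone and the two masked writes. For the partitioned case, a toy sketch of what the mask does, with made-up sizes (a vocabulary of 8 split across 2 partitions, this rank owning indices 4 through 7):

```python
import torch
import torch.nn.functional as F

# Made-up partition: this rank owns vocab indices [4, 8) of an 8-token vocab.
vocab_start_index, vocab_end_index = 4, 8
weight = torch.randn(4, 3)  # this partition's 4 rows of the embedding table

input_ = torch.tensor([1, 5, 7, 2])
input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
masked_input = input_.clone() - vocab_start_index
masked_input[input_mask] = 0              # out-of-range ids point at row 0...
output_parallel = F.embedding(masked_input, weight)
output_parallel[input_mask, :] = 0.0      # ...and are zeroed afterwards
# The later all-reduce sums the partitions; exactly one rank contributes a
# nonzero row per token, so every token ends up with its true embedding.
```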
megatron/mpu/mappings.py +25 −12

```diff
@@ -15,20 +15,19 @@
 
 import torch
 
-from .initialize import get_model_parallel_group
+from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
 from .utils import split_tensor_along_last_dim
 
 
 def _reduce(input_):
     """All-reduce the input tensor across model parallel group."""
-    group = get_model_parallel_group()
 
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if get_model_parallel_world_size() == 1:
         return input_
 
     # All-reduce.
-    torch.distributed.all_reduce(input_, group=group)
+    torch.distributed.all_reduce(input_, group=get_model_parallel_group())
 
     return input_
@@ -36,18 +35,17 @@ def _reduce(input_):
 def _split(input_):
     """Split the tensor along its last dimension and
     keep the corresponding slice."""
-    group = get_model_parallel_group()
+    world_size = get_model_parallel_world_size()
 
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if world_size == 1:
         return input_
 
     # Split along last dimension.
-    world_size = torch.distributed.get_world_size(group=group)
     input_list = split_tensor_along_last_dim(input_, world_size)
 
     # Note: torch.split does not create contiguous tensors by default.
-    rank = torch.distributed.get_rank(group=group)
+    rank = get_model_parallel_rank()
     output = input_list[rank].contiguous()
 
     return output
@@ -55,16 +53,15 @@ def _split(input_):
 
 def _gather(input_):
     """Gather tensors and concatenate along the last dimension."""
-    group = get_model_parallel_group()
+    world_size = get_model_parallel_world_size()
 
     # Bypass the function if we are using only 1 GPU.
-    if torch.distributed.get_world_size(group=group) == 1:
+    if world_size == 1:
         return input_
 
     # Size and dimension.
     last_dim = input_.dim() - 1
-    rank = torch.distributed.get_rank(group=group)
-    world_size = torch.distributed.get_world_size(group=group)
+    rank = get_model_parallel_rank()
 
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
@@ -79,6 +76,10 @@ def _gather(input_):
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
     @staticmethod
     def forward(ctx, input_):
         return input_
@@ -91,6 +92,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
 class _ReduceFromModelParallelRegion(torch.autograd.Function):
     """All-reduce the input from the model parallel region."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _reduce(input_)
@@ -103,6 +108,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
 class _ScatterToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the chunk corresponding to the rank."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _split(input_)
@@ -115,6 +124,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
 class _GatherFromModelParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatenate."""
 
+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather(input_)
+
     @staticmethod
     def forward(ctx, input_):
         return _gather(input_)
```
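Two details worth noting here. The `symbolic` staticmethods are the hook PyTorch's ONNX exporter looks for on `autograd.Function` subclasses; each one just evaluates the matching mapping eagerly, which is a no-op at model parallel world size 1, so presumably export is expected under that configuration. Separately, the comment about `torch.split` is load-bearing: the slices it returns are views carrying the parent's strides, which is why `_split()` calls `.contiguous()` on the slice it keeps. A small self-contained demonstration (plain torch, no distributed setup needed):

```python
import torch

# torch.split returns views that share storage with the parent; when
# splitting along the last dim their strides no longer describe a
# contiguous block, hence the .contiguous() in _split().
x = torch.arange(12.0).reshape(3, 4)
chunks = torch.split(x, 2, dim=-1)              # two (3, 2) views
print(chunks[1].is_contiguous())                # False
print(chunks[1].contiguous().is_contiguous())   # True: a real copy
```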