megatron/arguments.py  +5 −1

@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
     assert args.world_size % model_parallel_size == 0, 'world size is not'\
-        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
+        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
         'size ({})'.format(args.world_size,
                            args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
     args.data_parallel_size = args.world_size // model_parallel_size

@@ -116,6 +116,10 @@ def parse_args(extra_args_provider=None, defaults={},
             print('setting global batch size to {}'.format(
                 args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
+    if args.virtual_pipeline_model_parallel_size is not None:
+        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
+            'global batch size is not divisible by pipeline parallel size when '\
+            'using interleaved schedule'

     # Parameters dtype.
     args.params_dtype = torch.float
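For context, here is a minimal standalone sketch of the check this second hunk adds: when the interleaved (virtual pipeline) schedule is enabled, the global batch size must divide evenly by the pipeline-parallel size. The function name and the sample sizes below are illustrative, not part of the PR.

# Hypothetical standalone version of the new assertion; the function name
# and the sample values are illustrative only.
def check_interleaved_schedule_args(global_batch_size,
                                    pipeline_model_parallel_size,
                                    virtual_pipeline_model_parallel_size):
    # The interleaved schedule assigns multiple model chunks per pipeline
    # stage, so the global batch must split evenly across the stages.
    if virtual_pipeline_model_parallel_size is not None:
        assert global_batch_size % pipeline_model_parallel_size == 0, \
            'global batch size is not divisible by pipeline parallel size ' \
            'when using interleaved schedule'

check_interleaved_schedule_args(512, 8, 2)    # ok: 512 % 8 == 0
# check_interleaved_schedule_args(510, 8, 2)  # would raise AssertionError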
megatron/p2p_communication.py  +4 −3

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from functools import reduce
+import operator
 import torch

 from megatron import get_args

@@ -30,8 +32,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     tensor_recv_next = None
     tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
     if args.scatter_gather_tensors_in_pipeline:
-        tensor_chunk_shape = (
-            args.seq_length * args.micro_batch_size * args.hidden_size) // \
-            mpu.get_tensor_model_parallel_world_size()
+        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
+            mpu.get_tensor_model_parallel_world_size()
     else:
         tensor_chunk_shape = tensor_shape
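A quick illustration of the refactor in the second hunk: reduce(operator.mul, tensor_shape, 1) computes the element count of the communicated tensor directly from its shape tuple, so the chunk size no longer repeats the individual dimensions by hand. The sizes and world size below are placeholders, not values from the PR.

from functools import reduce
import operator

# Placeholder values; in Megatron these come from args and the mpu state.
seq_length, micro_batch_size, hidden_size = 1024, 4, 1536
tensor_model_parallel_world_size = 8

tensor_shape = (seq_length, micro_batch_size, hidden_size)

# New form: derive the flattened element count from the shape tuple ...
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
    tensor_model_parallel_world_size

# ... which matches the hand-written product it replaces.
assert tensor_chunk_shape == (seq_length * micro_batch_size * hidden_size) \
    // tensor_model_parallel_world_size
print(tensor_chunk_shape)  # 786432 elements per tensor-parallel rank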