megatron/arguments.py  (+12 −4)

@@ -116,10 +116,18 @@ def parse_args(extra_args_provider=None, defaults={},
         print('setting global batch size to {}'.format(
             args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
-    if args.virtual_pipeline_model_parallel_size is not None:
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
+            'number of layers is not divisible by number of layers per virtual ' \
+            'pipeline stage'
+        args.virtual_pipeline_model_parallel_size = \
+            (args.num_layers // args.pipeline_model_parallel_size) // \
+            args.num_layers_per_virtual_pipeline_stage
         assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
             'global batch size is not divisible by pipeline parallel size when ' \
             'using interleaved schedule'
+    else:
+        args.virtual_pipeline_model_parallel_size = None

     # Parameters dtype.
     args.params_dtype = torch.float

@@ -561,8 +569,8 @@ def _add_distributed_args(parser):
     group.add_argument('--model-parallel-size', type=int, default=None,
                        help='Old model parallel argument, do not use. Use '
                        '--tensor-model-parallel-size instead.')
-    group.add_argument('--virtual-pipeline-model-parallel-size', type=int, default=None,
-                       help='Number of virtual pipeline stages in physical stage.')
+    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
+                       help='Number of layers per virtual pipeline stage')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
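For context on what the new logic computes: with --num-layers-per-virtual-pipeline-stage, the virtual pipeline size is derived from the layer count and the pipeline-model-parallel size instead of being passed directly. A minimal standalone sketch of that arithmetic follows; the concrete values (24 layers, pipeline-model-parallel size 4, 3 layers per virtual stage) are illustrative assumptions, not taken from the PR.

    # Sketch of the derivation added in parse_args(); numbers are illustrative only.
    num_layers = 24
    pipeline_model_parallel_size = 4
    num_layers_per_virtual_pipeline_stage = 3

    # Mirrors the new assert: total layers must split evenly into virtual stages.
    assert num_layers % num_layers_per_virtual_pipeline_stage == 0

    # Each pipeline rank owns num_layers // pipeline_model_parallel_size layers (6 here);
    # splitting those into chunks of 3 yields 2 model chunks (virtual stages) per rank.
    virtual_pipeline_model_parallel_size = \
        (num_layers // pipeline_model_parallel_size) // \
        num_layers_per_virtual_pipeline_stage

    print(virtual_pipeline_model_parallel_size)  # -> 2

With the old flag, this value of 2 would have been passed explicitly as --virtual-pipeline-model-parallel-size; after this change it follows from --num-layers-per-virtual-pipeline-stage together with --num-layers and --pipeline-model-parallel-size.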