Commit dcef9069 authored by Deepak Narayanan

Change argument to control the number of model chunks in a stage

parent 5489bda9
+12 −4
@@ -116,10 +116,18 @@ def parse_args(extra_args_provider=None, defaults={},
             print('setting global batch size to {}'.format(
                 args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
-    if args.virtual_pipeline_model_parallel_size is not None:
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
+            'number of layers is not divisible by number of layers per virtual ' \
+            'pipeline stage'
+        args.virtual_pipeline_model_parallel_size = \
+            (args.num_layers // args.pipeline_model_parallel_size) // \
+            args.num_layers_per_virtual_pipeline_stage
         assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
             'global batch size is not divisible by pipeline parallel size when ' \
             'using interleaved schedule'
+    else:
+        args.virtual_pipeline_model_parallel_size = None
 
     # Parameters dtype.
     args.params_dtype = torch.float
@@ -561,8 +569,8 @@ def _add_distributed_args(parser):
     group.add_argument('--model-parallel-size', type=int, default=None,
                        help='Old model parallel argument, do not use. Use '
                        '--tensor-model-parallel-size instead.')
-    group.add_argument('--virtual-pipeline-model-parallel-size', type=int, default=None,
-                       help='Number of virtual pipeline stages in physical stage.')
+    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
+                       help='Number of layers per virtual pipeline stage')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
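
For illustration, here is a minimal, self-contained sketch (not part of the commit) of the arithmetic the new argument drives: the number of model chunks per pipeline stage, stored as args.virtual_pipeline_model_parallel_size, is now derived from the total layer count, the pipeline-parallel size, and the new per-chunk layer count instead of being passed in directly. The helper name and the example values below are hypothetical.

    def virtual_pipeline_size(num_layers, pipeline_model_parallel_size,
                              num_layers_per_virtual_pipeline_stage):
        # Mirrors the computation added to parse_args in the diff above
        # (hypothetical standalone helper, not Megatron-LM API).
        assert num_layers % num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        layers_per_stage = num_layers // pipeline_model_parallel_size
        return layers_per_stage // num_layers_per_virtual_pipeline_stage

    # Example: 24 layers over 4 pipeline stages gives 6 layers per stage;
    # with 2 layers per virtual stage, each stage holds 3 model chunks.
    assert virtual_pipeline_size(24, 4, 2) == 3

Under the new interface, the same configuration would be requested with --num-layers 24 --pipeline-model-parallel-size 4 --num-layers-per-virtual-pipeline-stage 2, leaving parse_args to set args.virtual_pipeline_model_parallel_size to 3 internally.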