Commit 7890681a authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'staging_distributed_cpu_fix' into 'staging'

Added logic for initialize_megatron to work with cpu only jobs. This is...

See merge request ADLR/megatron-lm!46
parents 1363e503 87bd026a
Loading
Loading
Loading
Loading
+18 −13
Original line number Diff line number Diff line
@@ -59,6 +59,7 @@ def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
@@ -66,8 +67,9 @@ def _initialize_distributed():
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()
        if device_count > 0:
            device = torch.cuda.current_device()
        local_rank = args.rank % torch.cuda.device_count()
            local_rank = args.rank % device_count
            assert local_rank == device, \
                'expected local-rank to be the same as rank % device-count.'

@@ -76,7 +78,8 @@ def _initialize_distributed():
        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        device = args.rank % torch.cuda.device_count()
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
@@ -94,6 +97,7 @@ def _initialize_distributed():
            init_method=init_method)

    # Set the model-parallel / data-parallel communicators.
    if device_count > 0:
        mpu.initialize_model_parallel(args.model_parallel_size)


@@ -112,6 +116,7 @@ def _set_random_seed(seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed))