Commit 87bd026a authored by Raul Puri's avatar Raul Puri
Browse files

Update initialize.py

parent 5448ca25
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -59,7 +59,7 @@ def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = 0
    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
@@ -69,7 +69,7 @@ def _initialize_distributed():
        args.world_size = torch.distributed.get_world_size()
        if device_count > 0:
            device = torch.cuda.current_device()
            local_rank = args.rank % torch.cuda.device_count()
            local_rank = args.rank % device_count
            assert local_rank == device, \
                'expected local-rank to be the same as rank % device-count.'

@@ -79,7 +79,7 @@ def _initialize_distributed():
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % torch.cuda.device_count()
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'