megatron/initialize.py  (+18 −13)

@@ -59,6 +59,7 @@ def _initialize_distributed():
     """Initialize torch.distributed and mpu."""
     args = get_args()

+    device_count = 0
     if torch.distributed.is_initialized():

         if args.rank == 0:
@@ -66,6 +67,7 @@ def _initialize_distributed():
                   'skipping initialization ...', flush=True)
         args.rank = torch.distributed.get_rank()
         args.world_size = torch.distributed.get_world_size()
-        device = torch.cuda.current_device()
-        local_rank = args.rank % torch.cuda.device_count()
-        assert local_rank == device, \
+        if device_count > 0:
+            device = torch.cuda.current_device()
+            local_rank = args.rank % torch.cuda.device_count()
+            assert local_rank == device, \
@@ -76,6 +78,7 @@ def _initialize_distributed():
         if args.rank == 0:
             print('> initializing torch distributed ...', flush=True)
         # Manually set the device ids.
-        device = args.rank % torch.cuda.device_count()
-        if args.local_rank is not None:
-            assert args.local_rank == device, \
+        if device_count > 0:
+            device = args.rank % torch.cuda.device_count()
+            if args.local_rank is not None:
+                assert args.local_rank == device, \
@@ -94,6 +97,7 @@ def _initialize_distributed():
             init_method=init_method)

     # Set the model-parallel / data-parallel communicators.
-    mpu.initialize_model_parallel(args.model_parallel_size)
+    if device_count > 0:
+        mpu.initialize_model_parallel(args.model_parallel_size)
@@ -112,6 +116,7 @@ def _set_random_seed(seed):
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
-        mpu.model_parallel_cuda_manual_seed(seed)
+        if torch.cuda.device_count() > 0:
+            mpu.model_parallel_cuda_manual_seed(seed)
     else:
         raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
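Taken together, these hunks make initialization safe on machines without GPUs: CUDA device binding, the model-parallel communicator setup, and the CUDA RNG seeding now run only when a device is present. Below is a minimal standalone sketch of that guard pattern, not the Megatron-LM code itself. It assumes device_count is derived from torch.cuda.device_count() (the visible hunk only shows the device_count = 0 default), falls back to the gloo backend on CPU, and substitutes torch.cuda.manual_seed for mpu.model_parallel_cuda_manual_seed so the example has no Megatron dependency; the function names are hypothetical.

# Minimal sketch of the "guard CUDA setup behind a device count" pattern.
# Hypothetical helper names; not taken from the patch.
import os
import random

import numpy as np
import torch


def initialize_distributed_sketch(rank, world_size, backend=None):
    # Assumption: the patch presumably fills device_count in from the runtime;
    # torch.cuda.device_count() is 0 on CPU-only machines.
    device_count = torch.cuda.device_count()

    if not torch.distributed.is_initialized():
        if device_count > 0:
            # Bind this rank to a local GPU before creating communicators.
            torch.cuda.set_device(rank % device_count)
        init_method = 'tcp://{}:{}'.format(
            os.getenv('MASTER_ADDR', 'localhost'),
            os.getenv('MASTER_PORT', '6000'))
        torch.distributed.init_process_group(
            backend=backend or ('nccl' if device_count > 0 else 'gloo'),
            world_size=world_size, rank=rank, init_method=init_method)

    # Mirrors the diff's `if device_count > 0:` guard around
    # mpu.initialize_model_parallel: GPU-only setup is skipped on CPU.
    return device_count


def set_random_seed_sketch(seed):
    if seed is None or seed <= 0:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.device_count() > 0:
        # The CUDA RNG (and Megatron's model-parallel RNG tracker) only
        # exists when a GPU is available, hence the guard in the diff.
        torch.cuda.manual_seed(seed)

With guards like these, the same launch path can be smoke-tested on a CPU-only machine (gloo backend) before running on a GPU cluster, which appears to be the intent of the patch.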