Commit a7ee77ea authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

flag for data parallel random initialization

parent fd8dd9c0
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -518,6 +518,9 @@ def _add_initialization_args(parser):
    group.add_argument('--seed', type=int, default=1234,
                       help='Random seed used for python, numpy, '
                       'pytorch, and cuda.')
    group.add_argument('--data-parallel-random-init', action='store_true',
                       help='Enable random initialization of params '
                       'across data parallel ranks')
    group.add_argument('--init-method-std', type=float, default=0.02,
                       help='Standard deviation of the zero mean normal '
                       'distribution used for weight initialization.')
+7 −2
Original line number Diff line number Diff line
@@ -142,6 +142,7 @@ def read_metadata(tracker_filename):

def get_rng_state():
    """ collect rng state across data parallel ranks """
    args = get_args()
    rng_state = {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
@@ -151,7 +152,8 @@ def get_rng_state():

    rng_state_list = None
    if torch.distributed.is_initialized() and \
            mpu.get_data_parallel_world_size() > 1:
            mpu.get_data_parallel_world_size() > 1 and \
            args.data_parallel_random_init:
        if mpu.get_data_parallel_rank() == 0:
            rng_state_list = \
                [None for i in range(mpu.get_data_parallel_world_size())]
@@ -407,7 +409,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
        try:
            if 'rng_state' in state_dict:
                # access rng_state for data parallel rank
                if args.data_parallel_random_init:
                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
                else:
                    rng_state = state_dict['rng_state'][0]
                random.setstate(rng_state['random_rng_state'])
                np.random.set_state(rng_state['np_rng_state'])
                torch.set_rng_state(rng_state['torch_rng_state'])
+7 −4
Original line number Diff line number Diff line
@@ -62,7 +62,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed)
        _set_random_seed(args.seed, args.data_parallel_random_init)

    # Set pytorch JIT layer fusion options.
    _set_jit_fusion_options()
@@ -203,11 +203,14 @@ def _init_autoresume():
        torch.distributed.barrier()


def _set_random_seed(seed_):
def _set_random_seed(seed_, data_parallel_random_init=False):
    """Set random seed for reproducability."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages and different data parallel ranks get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank()) + (10 * mpu.get_data_parallel_rank())
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure different data parallel ranks get different seeds
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
+4 −2
Original line number Diff line number Diff line
@@ -285,6 +285,8 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
                              args.accumulate_allreduce_grads_in_fp32,
                              args.use_contiguous_buffers_in_local_ddp)
                     for model_module in model]
            # broad cast params from data parallel src rank to other data parallel ranks
            if args.data_parallel_random_init:
                for model_module in model:
                    model_module.broadcast_params()
        else: