Loading megatron/training.py +6 −4 Original line number Diff line number Diff line Loading @@ -195,21 +195,23 @@ def get_model(model_provider_func): """Build the model.""" args = get_args() # Build model on cpu. pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() # Build model. if mpu.get_pipeline_model_parallel_world_size() > 1 and \ args.virtual_pipeline_model_parallel_size is not None: model = [] for i in range(args.virtual_pipeline_model_parallel_size): mpu.set_virtual_pipeline_model_parallel_rank(i) # Set pre_process and post_process only after virtual rank is set. pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() this_model = model_provider_func( pre_process=pre_process, post_process=post_process ) model.append(this_model) else: pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() model = model_provider_func( pre_process=pre_process, post_process=post_process Loading Loading
megatron/training.py +6 −4 Original line number Diff line number Diff line Loading @@ -195,21 +195,23 @@ def get_model(model_provider_func): """Build the model.""" args = get_args() # Build model on cpu. pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() # Build model. if mpu.get_pipeline_model_parallel_world_size() > 1 and \ args.virtual_pipeline_model_parallel_size is not None: model = [] for i in range(args.virtual_pipeline_model_parallel_size): mpu.set_virtual_pipeline_model_parallel_rank(i) # Set pre_process and post_process only after virtual rank is set. pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() this_model = model_provider_func( pre_process=pre_process, post_process=post_process ) model.append(this_model) else: pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() model = model_provider_func( pre_process=pre_process, post_process=post_process Loading