Commit e8fb052f authored by mshoeybi

made contiguous buffer in local ddp default

parent df6e3cd7
+7 −11
@@ -148,16 +148,11 @@ def parse_args(extra_args_provider=None, defaults={},
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)
 
-    # If we do accumulation and all-reduces in fp32, we need to have
-    # local DDP and we should set the use-contiguous-buffers-in-ddp.
+    # If we do accumulation and all-reduces in fp32, we need to have local DDP
+    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
     if args.accumulate_allreduce_grads_in_fp32:
         assert args.DDP_impl == 'local'
-        args.use_contiguous_buffers_in_ddp = True
-
-    # If we use a contiguous buffer to hold main grads, we need to have
-    # local DDP.
-    if args.use_contiguous_buffers_in_ddp:
-        assert args.DDP_impl == 'local'
+        assert args.use_contiguous_buffers_in_local_ddp
 
     if args.dataloader_type is None:
         args.dataloader_type = 'single'
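
Before this change, requesting fp32 gradient accumulation silently forced the contiguous buffer on; now the buffer is on by default, so the same check becomes an assertion that the user has not explicitly disabled it. A minimal standalone sketch (not repo code) of the new failure mode:

class Args:
    DDP_impl = 'local'
    accumulate_allreduce_grads_in_fp32 = True
    # i.e. the user passed --no-contiguous-buffers-in-local-ddp
    use_contiguous_buffers_in_local_ddp = False

args = Args()
if args.accumulate_allreduce_grads_in_fp32:
    assert args.DDP_impl == 'local'
    # Raises AssertionError instead of silently flipping the flag on:
    # fp32 grad accumulation requires the contiguous local-DDP buffer.
    assert args.use_contiguous_buffers_in_local_ddp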
@@ -584,9 +579,10 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
-    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
-                       help='If set, use contiguous buffer in DDP. Note that '
-                       'this option only works woth local DDP.' )
+    group.add_argument('--no-contiguous-buffers-in-local-ddp',
+                       action='store_false', help='If set, dont use '
+                       'contiguous buffer in local DDP.',
+                       dest='use_contiguous_buffers_in_local_ddp')
     group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                        help='Use scatter/gather to optimize communication of tensors in pipeline',
                        dest='scatter_gather_tensors_in_pipeline')
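
The default flip comes from the argparse idiom above: a --no-... flag with action='store_false' and a positive dest defaults the destination to True, so contiguous buffers are now opt-out rather than opt-in. A quick standalone demonstration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-contiguous-buffers-in-local-ddp',
                    action='store_false',
                    dest='use_contiguous_buffers_in_local_ddp')

# Default is True: contiguous buffers are on unless explicitly disabled.
print(parser.parse_args([]).use_contiguous_buffers_in_local_ddp)  # True
print(parser.parse_args(
    ['--no-contiguous-buffers-in-local-ddp']
).use_contiguous_buffers_in_local_ddp)                            # False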
+2 −2
@@ -100,7 +100,7 @@ def get_megatron_optimizer(model):
                                                  args.clip_grad,
                                                  args.log_num_zeros_in_grad,
                                                  params_have_main_grad,
-                                                 args.use_contiguous_buffers_in_ddp,
+                                                 args.use_contiguous_buffers_in_local_ddp,
                                                  args.bf16,
                                                  grad_scaler)

@@ -108,4 +108,4 @@ def get_megatron_optimizer(model):
     return FP32Optimizer(optimizer, args.clip_grad,
                          args.log_num_zeros_in_grad,
                          params_have_main_grad,
-                         args.use_contiguous_buffers_in_ddp)
+                         args.use_contiguous_buffers_in_local_ddp)
+10 −10
@@ -69,7 +69,7 @@ class MegatronOptimizer(ABC):
     def __init__(self, optimizer, clip_grad,
                  log_num_zeros_in_grad,
                  params_have_main_grad,
-                 use_contiguous_buffers_in_ddp):
+                 use_contiguous_buffers_in_local_ddp):
 
         """Input optimizer is the base optimizer for example Adam."""
         self.optimizer = optimizer
@@ -78,9 +78,9 @@ class MegatronOptimizer(ABC):
         self.clip_grad = clip_grad
         self.log_num_zeros_in_grad = log_num_zeros_in_grad
         self.params_have_main_grad = params_have_main_grad
-        self.use_contiguous_buffers_in_ddp = use_contiguous_buffers_in_ddp
+        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp
 
-        if self.use_contiguous_buffers_in_ddp:
+        if self.use_contiguous_buffers_in_local_ddp:
             assert self.params_have_main_grad, \
                 "use of contiguous buffer requires that params have main grad"
@@ -193,12 +193,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
     """
 
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, use_contiguous_buffers_in_ddp,
+                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                  bf16, grad_scaler):
 
         super(Float16OptimizerWithFloat16Params, self).__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad, use_contiguous_buffers_in_ddp)
+            params_have_main_grad, use_contiguous_buffers_in_local_ddp)
 
         self.bf16 = bf16
         self.grad_scaler = grad_scaler
@@ -323,7 +323,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                 # persist and therefore should not be deallocated.)
                 model_param.grad = None
                 if self.params_have_main_grad and \
-                   not self.use_contiguous_buffers_in_ddp:
+                   not self.use_contiguous_buffers_in_local_ddp:
                     model_param.main_grad = None
 
         # For fp32 grads, we need to reset the grads to main grad.
@@ -335,7 +335,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                     # Safe to de-reference model's main_grad after copying.
                     # (If using contiguous buffers, main_grad's memory should
                     # persist and therefore should not be deallocated.)
-                    if not self.use_contiguous_buffers_in_ddp:
+                    if not self.use_contiguous_buffers_in_local_ddp:
                         model_param.main_grad = None
 
     def _unscale_main_grads_and_check_for_nan(self):
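
The code comment spells out the rule encoded here: main_grad is dropped only when it owns its storage. When it is a view into the contiguous buffer, it must persist for the next iteration, and dropping the reference would not release the underlying memory anyway. A small standalone illustration of that property:

import torch

buffer = torch.zeros(8)
view = buffer[2:6]   # stand-in for a main_grad backed by the buffer
view.fill_(1.0)

view = None          # dropping the view frees nothing ...
print(buffer)        # ... the buffer (and its data) is still alive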
@@ -491,11 +491,11 @@ class FP32Optimizer(MegatronOptimizer):
     def __init__(self, optimizer, clip_grad,
                  log_num_zeros_in_grad,
                  params_have_main_grad,
-                 use_contiguous_buffers_in_ddp):
+                 use_contiguous_buffers_in_local_ddp):
 
         super(FP32Optimizer, self).__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad, use_contiguous_buffers_in_ddp)
+            params_have_main_grad, use_contiguous_buffers_in_local_ddp)
 
         self._scale = torch.cuda.FloatTensor([1.0])
@@ -525,7 +525,7 @@ class FP32Optimizer(MegatronOptimizer):
                     # Safe to de-reference model's main_grad after copying.
                     # (If using contiguous buffers, main_grad's memory should
                     # persist and therefore should not be deallocated.)
-                    if not self.use_contiguous_buffers_in_ddp:
+                    if not self.use_contiguous_buffers_in_local_ddp:
                         param.main_grad = None
 
         # Clip gradients.
+2 −2
@@ -253,7 +253,7 @@ def get_model(model_provider_func):
     if args.DDP_impl == 'local':
         model = [LocalDDP(model_module,
                           args.accumulate_allreduce_grads_in_fp32,
-                          args.use_contiguous_buffers_in_ddp)
+                          args.use_contiguous_buffers_in_local_ddp)
                  for model_module in model]
         return model

@@ -351,7 +351,7 @@ def train_step(forward_step_func, data_iterator,
     timers = get_timers()
 
     # Set grad to zero.
-    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
+    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
         for partition in model:
             partition.zero_grad_buffer()
     optimizer.zero_grad()
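
With the contiguous buffer on, train_step zeroes the flat buffer directly rather than each param.grad. A minimal sketch of what zero_grad_buffer amounts to, assuming the flat-buffer layout sketched earlier (the class name and field are hypothetical; only zero_grad_buffer's name comes from the diff):

class LocalDDPSketch:
    """Illustrative stand-in for the local DistributedDataParallel wrapper."""

    def __init__(self, grad_buffer):
        # grad_buffer: the flat tensor every param.main_grad views into.
        self._grad_buffer = grad_buffer

    def zero_grad_buffer(self):
        # One in-place zero resets every main_grad view at once,
        # instead of looping over parameters.
        self._grad_buffer.zero_()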