Commit 9a297541 authored by Jared Casper

Merge branch 'optimize_refactor' into 'main'

Optimizer refactor

See merge request ADLR/megatron-lm!197
parents b69bc7ef 42c6d8e7
+12 −16
@@ -101,11 +101,6 @@ def parse_args(extra_args_provider=None, defaults={},
                  args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
 
-    # Fp16 loss scaling.
-    args.dynamic_loss_scale = False
-    if args.loss_scale is None:
-        args.dynamic_loss_scale = True
-
     # Parameters dtype.
     args.params_dtype = torch.float
     if args.fp16:
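
Note: with the precomputed args.dynamic_loss_scale flag deleted from parse_args, the dynamic-vs-static decision is presumably made inside the refactored optimizer setup instead. A one-line sketch of the equivalent check, using only names visible in this hunk:

    # Hypothetical equivalent of the deleted flag: dynamic loss scaling is
    # in effect exactly when no static --loss-scale value was supplied.
    dynamic_loss_scale = args.loss_scale is None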
@@ -438,6 +433,18 @@ def _add_mixed_precision_args(parser):
 
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode.')
+    group.add_argument('--loss-scale', type=float, default=None,
+                       help='Static loss scaling, positive power of 2 '
+                       'values can improve fp16 convergence. If None, dynamic'
+                       'loss scaling is used.')
+    group.add_argument('--initial-loss-scale', type=float, default=2**32,
+                       help='Initial loss-scale for dynamic loss scaling.')
+    group.add_argument('--min-loss-scale', type=float, default=1.0,
+                       help='Minimum loss scale for dynamic loss scale.')
+    group.add_argument('--loss-scale-window', type=float, default=1000,
+                       help='Window over which to raise/lower dynamic scale.')
+    group.add_argument('--hysteresis', type=int, default=2,
+                       help='hysteresis for dynamic loss scaling')
     group.add_argument('--fp32-residual-connection', action='store_true',
                        help='Move residual connections to fp32.')
     group.add_argument('--apply-query-key-layer-scaling', action='store_true',
@@ -448,21 +455,10 @@ def _add_mixed_precision_args(parser):
                        help='Run attention masking and softmax in fp32.')
     group.add_argument('--fp32-allreduce', action='store_true',
                        help='All-reduce in fp32')
-    group.add_argument('--hysteresis', type=int, default=2,
-                       help='hysteresis for dynamic loss scaling')
-    group.add_argument('--loss-scale', type=float, default=None,
-                       help='Static loss scaling, positive power of 2 '
-                       'values can improve fp16 convergence. If None, dynamic'
-                       'loss scaling is used.')
-    group.add_argument('--loss-scale-window', type=float, default=1000,
-                       help='Window over which to raise/lower dynamic scale.')
-    group.add_argument('--min-scale', type=float, default=1,
-                       help='Minimum loss scale for dynamic loss scale.')
     group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                        help='Move the cross entropy unreduced loss calculation'
                        'for lm head to fp16.')
-
 
     return parser
 
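Note: the five flags now grouped together above (--loss-scale, --initial-loss-scale, --min-loss-scale, --loss-scale-window, --hysteresis) are the usual knobs of a dynamic grad scaler. A minimal sketch of how such a scaler typically consumes them (illustrative only, not the code added by this MR):

    class DynamicGradScalerSketch:
        """Start high, back off on overflow, grow after a clean window."""

        def __init__(self, initial_scale=2**32, min_scale=1.0,
                     growth_interval=1000, hysteresis=2):
            self.scale = initial_scale              # --initial-loss-scale
            self.min_scale = min_scale              # --min-loss-scale
            self.growth_interval = growth_interval  # --loss-scale-window
            self.hysteresis = hysteresis            # --hysteresis
            self._hysteresis_left = hysteresis
            self._good_steps = 0

        def update(self, found_inf):
            if found_inf:
                self._good_steps = 0
                self._hysteresis_left -= 1
                # Only halve after `hysteresis` consecutive overflow steps.
                if self._hysteresis_left <= 0:
                    self.scale = max(self.scale / 2.0, self.min_scale)
                    self._hysteresis_left = self.hysteresis
            else:
                self._good_steps += 1
                # Double after a full overflow-free window of steps.
                if self._good_steps % self.growth_interval == 0:
                    self.scale *= 2.0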
+5 −1
@@ -205,12 +205,16 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
     try:
         state_dict = torch.load(checkpoint_name, map_location='cpu')
     except ModuleNotFoundError:
+        from megatron.fp16_deprecated import loss_scaler
         # For backward compatibility.
         print_rank_0(' > deserializing using the old code structure ...')
         sys.modules['fp16.loss_scaler'] = sys.modules[
-            'megatron.fp16.loss_scaler']
+            'megatron.fp16_deprecated.loss_scaler']
+        sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
+            'megatron.fp16_deprecated.loss_scaler']
         state_dict = torch.load(checkpoint_name, map_location='cpu')
         sys.modules.pop('fp16.loss_scaler', None)
+        sys.modules.pop('megatron.fp16.loss_scaler', None)
     except BaseException:
         print_rank_0('could not load the checkpoint')
         sys.exit()
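
Note: the aliasing above works because pickle records each object's defining module path, so checkpoints written when the loss scaler lived at fp16.loss_scaler or megatron.fp16.loss_scaler look up those exact names at load time. Registering both legacy names against the relocated megatron.fp16_deprecated.loss_scaler module satisfies the lookup, and popping them afterwards keeps the aliases from leaking. The same pattern in a generic, standalone form (a sketch, not code from this MR):

    import importlib
    import sys

    import torch

    def torch_load_with_module_alias(path, old_name, new_name):
        # Make pickle's lookup of `old_name` resolve to the relocated module.
        sys.modules[old_name] = importlib.import_module(new_name)
        try:
            return torch.load(path, map_location='cpu')
        finally:
            sys.modules.pop(old_name, None)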
+0 −2
@@ -128,8 +128,6 @@ def get_samples_mapping_(indexed_dataset,
         print_rank_0(' > building sapmles index mapping for {} ...'.format(
             name))
         # First compile and then import.
-        from megatron.data.dataset_utils import compile_helper
-        compile_helper()
         from megatron.data import helpers
         samples_mapping = helpers.build_mapping(
             indexed_dataset.doc_idx,
+0 −7
@@ -49,13 +49,6 @@ class BlendableDataset(torch.utils.data.Dataset):
         self.dataset_index = np.zeros(self.size, dtype=np.uint8)
         self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
 
-        if torch.distributed.get_rank() == 0:
-            from megatron.data.dataset_utils import compile_helper
-            compile_helper()
-        # Simple barrier
-        tmp = torch.cuda.LongTensor([1])
-        torch.distributed.all_reduce(tmp, group=mpu.get_data_parallel_group())
-
         from megatron.data import helpers
         helpers.build_blending_indices(self.dataset_index,
                                        self.dataset_sample_index,
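
Note: this hunk and the neighboring +0 −2 hunks delete per-dataset copies of the same compile-on-first-use pattern: rank 0 builds the C++ helpers extension while an all-reduce doubles as a barrier so other ranks do not import it before it exists. Presumably the refactor consolidates this into a single call during initialization; a sketch of such a one-time version, under that assumption (requires an initialized process group, as the deleted code did):

    import torch

    def compile_data_helpers_once():
        # Only rank 0 compiles; every other rank blocks until it is done.
        if torch.distributed.get_rank() == 0:
            from megatron.data.dataset_utils import compile_helper
            compile_helper()
        torch.distributed.barrier()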
+0 −2
@@ -269,8 +269,6 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
             start_time = time.time()
             # Use C++ implementation for speed.
             # First compile and then import.
-            from megatron.data.dataset_utils import compile_helper
-            compile_helper()
             from megatron.data import helpers
             assert doc_idx.dtype == np.int32
             assert sizes.dtype == np.int32