Commit f0a445fa authored by mohammad

added consumed tokens to checkpoints and some refactoring

parent 4311b695
megatron/arguments.py  +3 −0
@@ -72,6 +72,9 @@ def parse_args(extra_args_provider=None, defaults={},
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

+    # Consumed tokens.
+    args.consumed_train_samples = 0
+    args.consumed_valid_samples = 0

    # Set input defaults.
    for key in defaults:
megatron/checkpointing.py  +13 −2
@@ -89,7 +89,8 @@ def get_checkpoint_tracker_filename(checkpoints_path):
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


-def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                    consumed_train_samples=None, consumed_valid_samples=None):
    """Save a model checkpoint."""
    args = get_args()

@@ -103,6 +104,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 2.0
        state_dict['iteration'] = iteration
+        if consumed_train_samples:
+            state_dict['consumed_train_samples'] = consumed_train_samples
+        if consumed_valid_samples:
+            state_dict['consumed_valid_samples'] = consumed_valid_samples
        state_dict['model'] = model.state_dict_for_save_checkpoint()

        # Optimizer stuff.
@@ -214,6 +219,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                                 checkpoint_name))
                sys.exit()

+    if 'consumed_train_samples' in state_dict:
+        assert args.consumed_train_samples == 0
+        args.consumed_train_samples = state_dict['consumed_train_samples']
+    if 'consumed_valid_samples' in state_dict:
+        assert args.consumed_valid_samples == 0
+        args.consumed_valid_samples = state_dict['consumed_valid_samples']

    # Check arguments.
    if 'args' in state_dict:
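
Taken together, these hunks make the counters round-trip through the checkpoint as ordinary dictionary entries. A minimal standalone sketch of that save/load symmetry, using a bare torch.save/torch.load pair and hypothetical values in place of Megatron's full checkpoint layout:

    import torch

    # Save side: the truthiness guard above means a counter still at 0
    # is simply omitted from the checkpoint.
    state_dict = {'iteration': 1000}
    consumed_train_samples = 512000  # hypothetical value
    if consumed_train_samples:
        state_dict['consumed_train_samples'] = consumed_train_samples
    torch.save(state_dict, 'ckpt.pt')

    # Load side: the assert checks the counter is still at its
    # parse_args() default (0) before it is overwritten.
    consumed_train_samples = 0
    loaded = torch.load('ckpt.pt')
    if 'consumed_train_samples' in loaded:
        assert consumed_train_samples == 0
        consumed_train_samples = loaded['consumed_train_samples']

One consequence of the truthiness guard is that a counter equal to 0 is never written, which the load side then treats as "key absent"; that is harmless here only because 0 is also the parse_args() default.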
megatron/data/data_loaders.py  +32 −1
@@ -13,7 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatorn Sampler."""
"""Dataloaders."""


+import torch
+
+from megatron import get_args
+from megatron import mpu
+
+
+def build_pretraining_data_loader(dataset, consumed_samples):
+    """Build dataloader given an input dataset."""
+
+    if dataset is None:
+        return None
+    args = get_args()
+
+    world_size = mpu.get_data_parallel_world_size()
+    global_batch_size = args.batch_size * world_size
+
+    # Megatron sampler
+    batch_sampler = MegatronPretrainingSampler(
+        total_samples=len(dataset),
+        consumed_samples=consumed_samples,
+        global_batch_size=global_batch_size,
+        rank=mpu.get_data_parallel_rank(),
+        world_size=world_size)
+
+    # Torch dataloader.
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=args.num_workers,
+                                       pin_memory=True)
+
+
class MegatronPretrainingSampler:
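
The body of MegatronPretrainingSampler is collapsed in this view. For orientation, here is a minimal sketch of how such a batch sampler can honor consumed_samples, assuming each data-parallel rank reads a contiguous slice of every global batch; the class name and details below are illustrative, not the actual implementation:

    class PretrainingSamplerSketch:
        """Hypothetical stand-in for MegatronPretrainingSampler."""

        def __init__(self, total_samples, consumed_samples,
                     global_batch_size, rank, world_size):
            self.total_samples = total_samples
            self.consumed_samples = consumed_samples
            self.global_batch_size = global_batch_size
            # Each rank takes a contiguous chunk of the global batch.
            self.per_rank = global_batch_size // world_size
            self.start = rank * self.per_rank

        def __iter__(self):
            batch = []
            # Begin at consumed_samples so a resumed job skips data
            # it has already seen.
            for idx in range(self.consumed_samples, self.total_samples):
                batch.append(idx)
                if len(batch) == self.global_batch_size:
                    yield batch[self.start:self.start + self.per_rank]
                    batch = []

Handed to torch.utils.data.DataLoader via batch_sampler=..., each rank then receives microbatches of args.batch_size samples drawn from disjoint slices of the same global batch.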
megatron/training.py  +29 −13
@@ -37,7 +37,7 @@ from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import make_data_loader
+from megatron.data.data_loaders import build_pretraining_data_loader
from megatron.utils import report_memory


@@ -104,7 +104,9 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                                   iteration, False)

    if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler)
+        save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                        consumed_train_samples=args.consumed_train_samples,
+                        consumed_valid_samples=args.consumed_valid_samples)

    if args.do_test:
        # Run on test data.
@@ -224,7 +226,8 @@ def setup_model_and_optimizer(model_provider_func):
    while hasattr(unwrapped_model, 'module'):
        unwrapped_model = unwrapped_model.module

-    if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'):
+    if args.iteration == 0 and hasattr(unwrapped_model,
+                                       'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        unwrapped_model.init_state_dict_from_bert()

@@ -414,6 +417,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                             optimizer,
                                             lr_scheduler)
        iteration += 1
+        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
+                                       args.batch_size

        # Logging.
        loss_scale = None
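
Each training step consumes exactly one global batch, so the counter advances by the data-parallel world size times the per-rank batch size on every iteration; evaluate() below advances consumed_valid_samples the same way. A quick sanity check with hypothetical sizes:

    # 8 data-parallel ranks, per-rank batch size 16 (hypothetical values).
    world_size, batch_size = 8, 16
    consumed_train_samples = 0
    for _ in range(3):
        consumed_train_samples += world_size * batch_size
    assert consumed_train_samples == 384  # 128 samples per iteration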
@@ -433,7 +438,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
        # Checkpointing
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler,
+                            consumed_train_samples=args.consumed_train_samples,
+                            consumed_valid_samples=args.consumed_valid_samples)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
@@ -472,6 +479,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                                                            args.eval_iters))
            # Forward evaluation.
            _, loss_dict = forward_step_func(data_iterator, model)
+            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
+                                           * args.batch_size
            # Reduce across processes.
            for key in loss_dict:
                total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
@@ -517,11 +526,19 @@ def build_train_valid_test_data_iterators(
    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        # Rank, size, and global batch size.
-        data_parallel_size = mpu.get_data_parallel_world_size()
-        global_batch_size = args.batch_size * data_parallel_size
+
+    # Rank and global batch size.
+    data_parallel_size = mpu.get_data_parallel_world_size()
+    global_batch_size = args.batch_size * data_parallel_size
+    # Backward compatibility, assume fixed batch size.
+    if args.iteration > 0 and args.consumed_train_samples == 0:
+        args.consumed_train_samples = args.iteration * global_batch_size
+    if args.iteration > 0 and args.consumed_valid_samples == 0:
+        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
+            args.eval_iters * global_batch_size
+
+    # Data loader only on rank 0 of each model parallel group.
+    if mpu.get_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        train_iters = args.train_iters
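
The backward-compatibility branch above reconstructs both counters for checkpoints written before this change, which recorded only the iteration number; the reconstruction is exact only under the stated assumption of a fixed batch size. With hypothetical settings the arithmetic looks like this:

    # Old checkpoint: iteration is known, counters are absent (hypothetical).
    iteration, global_batch_size = 1000, 512
    eval_interval, eval_iters = 100, 10

    consumed_train_samples = iteration * global_batch_size
    # 1000 iterations * 512 samples each = 512000 training samples.
    consumed_valid_samples = (iteration // eval_interval) * \
        eval_iters * global_batch_size
    # 10 eval rounds * 10 iters * 512 = 51200 validation samples.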
@@ -540,12 +557,11 @@ def build_train_valid_test_data_iterators(
            train_val_test_num_samples)

        # Build dataloaders.
-        comsumed_samples = args.iteration * global_batch_size
-        train_dataloader = make_data_loader(train_ds, comsumed_samples)
-        comsumed_samples = (args.iteration // args.eval_interval) * \
-            args.eval_iters * global_batch_size
-        valid_dataloader = make_data_loader(valid_ds, comsumed_samples)
-        test_dataloader = make_data_loader(test_ds, comsumed_samples)
+        train_dataloader = build_pretraining_data_loader(
+            train_ds, args.consumed_train_samples)
+        valid_dataloader = build_pretraining_data_loader(
+            valid_ds, args.consumed_valid_samples)
+        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
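
Note that the test dataloader is always built with zero consumed samples: the test pass runs once at the end of training, so there is nothing to resume, whereas the train and valid loaders pick up exactly where the restored counters say the previous run stopped.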
megatron/utils.py  +0 −25
@@ -24,7 +24,6 @@ from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.checkpointing import save_checkpoint
-from megatron.data.sampler import MegatronPretrainingSampler
from megatron.fp16 import FP16_Optimizer


@@ -89,30 +88,6 @@ def check_adlr_autoresume_termination(iteration, model,
        sys.exit(0)


-def make_data_loader(dataset, consumed_samples):
-    """Buld dataloader given an input dataset."""
-    if dataset is None:
-        return None
-    args = get_args()
-
-    # Data parallel arguments.
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
-    global_batch_size = args.batch_size * world_size
-    num_workers = args.num_workers
-
-    # Megatron sampler
-    batch_sampler = MegatronPretrainingSampler(
-        total_samples=len(dataset), consumed_samples=consumed_samples,
-        global_batch_size=global_batch_size, rank=rank, world_size=world_size)
-
-    # Torch dataloader.
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)


def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,