Commit 8eff2a99 authored by Mostofa Patwary
Browse files

Remove the function get_one_epoch_dataloader and add an assert

parent 38898931
Loading
Loading
Loading
Loading
+0 −27
Original line number Diff line number Diff line
@@ -9,33 +9,6 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
from megatron import get_args, get_tokenizer, print_rank_0, mpu


def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Deprecated single-epoch dataloader builder for indexing jobs.

    The former implementation wrapped ``dataset`` in a sequential sampler
    and a ``DistributedBatchSampler`` (with ``drop_last=False`` so that all
    the data was returned), using a global batch size of
    ``micro_batch_size`` times the data-parallel world size.  That sampler
    is deprecated, so this function now fails unconditionally until the
    implementation is changed.

    Args:
        dataset: Dataset that was to be iterated for one epoch (unused).
        micro_batch_size: Per-rank batch size; previously defaulted to
            ``args.micro_batch_size`` when ``None`` (unused).

    Raises:
        AssertionError: Always, with the deprecation message.
    """
    # Raise explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would let execution fall through to
    # the deprecated sampler import.  The exception type and message are
    # kept identical so existing callers observe the same failure.
    raise AssertionError(
        'DistributedBatchSampler deprecated, change the implementation')


def get_ict_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_mask',
+3 −0
Original line number Diff line number Diff line
@@ -99,6 +99,9 @@ def forward_step(data_iterator, model, input_tensor):

    micro_batch_size = query_logits.shape[0]
    # recall we assert that tensor_model_parallel_size == 1
    assert mpu.get_tensor_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    global_batch_size = dist.get_world_size() * micro_batch_size
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)