megatron/data/biencoder_dataset_utils.py (+0 −27)

@@ -9,33 +9,6 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
 
 
-def get_one_epoch_dataloader(dataset, micro_batch_size=None):
-    """Specifically one epoch to be used in an indexing job."""
-    args = get_args()
-
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
-    if micro_batch_size is None:
-        micro_batch_size = args.micro_batch_size
-    global_batch_size = micro_batch_size * world_size
-    num_workers = args.num_workers
-
-    sampler = torch.utils.data.SequentialSampler(dataset)
-    # importantly, drop_last must be False to get all the data.
-    assert False, 'DistributedBatchSampler deprecated, change the implementation'
-    from megatron.data.samplers import DistributedBatchSampler
-    batch_sampler = DistributedBatchSampler(sampler,
-                                            batch_size=global_batch_size,
-                                            drop_last=False,
-                                            rank=rank,
-                                            world_size=world_size)
-
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)
-
-
 def get_ict_batch(data_iterator):
     # Items and their type.
     keys = ['query_tokens', 'query_mask',
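The removed helper already hard-failed (`assert False`) because `DistributedBatchSampler` was deprecated, so this change deletes it outright rather than porting it. For orientation only, here is a minimal sketch of how the same behavior — one sequential epoch over the full dataset, sharded across data-parallel ranks with no samples dropped — could be rebuilt on plain PyTorch. The function name and arguments mirror the deleted helper but are our own reconstruction, not code from this PR:

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def get_one_epoch_dataloader(dataset, micro_batch_size, rank, world_size,
                             num_workers=2):
    """One deterministic epoch over `dataset`, sharded across ranks.

    Hypothetical replacement sketch; not the code this PR adds.
    """
    # shuffle=False keeps the sequential order of the deleted helper;
    # drop_last=False keeps every sample, which an indexing job needs.
    # Note: when len(dataset) is not divisible by world_size,
    # DistributedSampler repeats a few samples so every rank sees the
    # same number of batches; de-duplicate downstream by sample index.
    sampler = DistributedSampler(dataset, num_replicas=world_size,
                                 rank=rank, shuffle=False, drop_last=False)
    return DataLoader(dataset,
                      sampler=sampler,
                      batch_size=micro_batch_size,
                      num_workers=num_workers,
                      pin_memory=True)
```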
pretrain_ict.py (+3 −0)

@@ -99,6 +99,9 @@ def forward_step(data_iterator, model, input_tensor):
     micro_batch_size = query_logits.shape[0]
     # recall we assert that tensor_model_parallel_size == 1
+    assert mpu.get_tensor_model_parallel_world_size() == 1, \
+        "Model parallel size > 1 not supported for ICT"
+
     global_batch_size = dist.get_world_size() * micro_batch_size
     all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
     all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
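The new assertion enforces an assumption the next line already makes: `global_batch_size = dist.get_world_size() * micro_batch_size` treats the global world size as the data-parallel size, which only holds when the tensor-model-parallel group has size 1. To make the surrounding code concrete, here is a minimal sketch of what an `AllgatherFromDataParallelRegion`-style autograd op looks like under that same assumption — an illustration, not the repo's actual implementation:

```python
import torch
import torch.distributed as dist

class AllgatherFromDataParallelRegion(torch.autograd.Function):
    """All-gather per-rank logits into one global batch in the forward
    pass; hand each rank back the gradient slice for its own rows.

    Illustrative sketch only; assumes tensor_model_parallel_size == 1,
    so the default process group is exactly the data-parallel group,
    and assumes every rank has the same micro_batch_size.
    """

    @staticmethod
    def forward(ctx, input_):
        ctx.world_size = dist.get_world_size()
        # One buffer per rank; after concat the batch dimension becomes
        # global_batch_size = world_size * micro_batch_size.
        gathered = [torch.empty_like(input_) for _ in range(ctx.world_size)]
        dist.all_gather(gathered, input_.contiguous())
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # The forward output stacks one micro-batch per rank, so the
        # gradient w.r.t. this rank's input is just its own slice.
        chunks = grad_output.chunk(ctx.world_size, dim=0)
        return chunks[dist.get_rank()].contiguous()
```

Gathering the logits this way lets every rank compute the retrieval loss over the full global batch of in-batch negatives, which is why the data-parallel world size must be unambiguous.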