examples/create_embeddings.sh (new file, +32 −0)

```bash
#!/bin/bash

# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)

RANK=0
WORLD_SIZE=1

# Wikipedia data can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path to store embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>

python tools/create_doc_index.py \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --retriever-seq-length 256 \
    --max-position-embeddings 512 \
    --load ${CHECKPOINT_PATH} \
    --evidence-data-path ${EVIDENCE_DATA_DIR} \
    --embedding-path ${EMBEDDING_PATH} \
    --indexer-log-interval 1000 \
    --indexer-batch-size 128 \
    --vocab-file bert-vocab.txt \
    --num-workers 2 \
    --fp16
```
megatron/arguments.py (+11 −0)

```diff
@@ -629,6 +629,12 @@ def _add_data_args(parser):
                             'This should be exclusive of --seq-length')
     group.add_argument('--decoder-seq-length', type=int, default=None,
                        help="Maximum decoder sequence length to process.")
+    group.add_argument('--retriever-seq-length', type=int, default=256,
+                       help='Maximum sequence length for the biencoder model '
+                            ' for retriever')
+    group.add_argument('--sample-rate', type=float, default=1.0,
+                       help='sample rate for training data. Supposed to be 0 '
+                            ' < sample_rate < 1')
     group.add_argument('--mask-prob', type=float, default=0.15,
                        help='Probability of replacing a token with mask.')
     group.add_argument('--short-seq-prob', type=float, default=0.1,
@@ -698,6 +704,8 @@ def _add_biencoder_args(parser):
                             'ICT dataset')
     group.add_argument('--use-one-sent-docs', action='store_true',
                        help='Whether to use one sentence documents in ICT')
+    group.add_argument('--evidence-data-path', type=str, default=None,
+                       help='Path to Wikipedia Evidence frm DPR paper')

     # training
     group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
@@ -712,6 +720,9 @@
                        help='Whether create the FaissMIPSIndex on GPU')
     group.add_argument('--block-data-path', type=str, default=None,
                        help='Where to save/load BlockData to/from')
+    group.add_argument('--embedding-path', type=str, default=None,
+                       help='Where to save/load Open-Retrieval Embedding'
+                            ' data to/from')

     # indexer
     group.add_argument('--indexer-batch-size', type=int, default=128,
```
megatron/checkpointing.py (+14 −12)

```diff
@@ -383,40 +383,42 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
     return iteration


-def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
-    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
+def load_biencoder_checkpoint(model, only_query_model=False,
+                              only_context_model=False, custom_load_path=None):
+    """
+    selectively load retrieval models for indexing/retrieving
+    from saved checkpoints
+    """

     args = get_args()
+
     model = utils.unwrap_model(model)

-    load_path = args.load if from_realm_chkpt else args.ict_load
+    load_path = custom_load_path if custom_load_path is not None else args.load

     tracker_filename = get_checkpoint_tracker_filename(load_path)
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())

-    # assert iteration > 0
     checkpoint_name = get_checkpoint_name(load_path, iteration, False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))

     state_dict = torch.load(checkpoint_name, map_location='cpu')
-    ict_state_dict = state_dict['model']
-    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
-        print(" loading ICT state dict from REALM", flush=True)
-        ict_state_dict = ict_state_dict['retriever']['ict_model']
+    ret_state_dict = state_dict['model']

     if only_query_model:
-        ict_state_dict.pop('context_model')
-    if only_block_model:
-        ict_state_dict.pop('question_model')
+        ret_state_dict.pop('context_model')
+    if only_context_model:
+        ret_state_dict.pop('query_model')

-    model.load_state_dict(ict_state_dict)
+    assert len(model) == 1
+    model[0].load_state_dict(ret_state_dict)
     torch.distributed.barrier()

     if mpu.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))

     return model
```
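The renamed loader is presumably called from the indexing tools driven by examples/create_embeddings.sh. A minimal sketch of such a call, assuming `biencoder_model` is the usual Megatron list of model chunks (the loader asserts `len(model) == 1`) and that only the context encoder is needed for building a document index:

```python
from megatron.checkpointing import load_biencoder_checkpoint

# Sketch only: `biencoder_model` is assumed to be the list of model chunks
# returned by Megatron's model setup helpers for the biencoder.
# only_context_model=True drops the query tower, keeping just the context
# encoder used to embed evidence passages; custom_load_path=None falls back
# to args.load.
biencoder_model = load_biencoder_checkpoint(biencoder_model,
                                            only_context_model=True,
                                            custom_load_path=None)
```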
megatron/data/biencoder_dataset_utils.py (+43 −3)

```diff
@@ -4,9 +4,49 @@
 import time

 import numpy as np
 import torch

-from megatron import mpu, print_rank_0
-from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron import get_args, get_tokenizer, print_rank_0, mpu
+from megatron import get_args, get_tokenizer, mpu, print_rank_0
+from megatron.data.dataset_utils import create_masked_lm_predictions, \
+    pad_and_convert_to_numpy
+from megatron.data.data_samplers import MegatronPretrainingSampler
+
+
+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)
+    # (source_length, target_length)
+    return mask
+
+
+def get_one_epoch_dataloader(dataset, micro_batch_size=None):
+    """Specifically one epoch to be used in an indexing job."""
+    args = get_args()
+
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    if micro_batch_size is None:
+        micro_batch_size = args.micro_batch_size
+    global_batch_size = micro_batch_size * world_size
+    num_workers = args.num_workers
+
+    # Use megatron's sampler with consumed samples set to 0 as
+    # this is only for evaluation and don't intend to resume half way.
+    # Also, set the drop last to false as don't intend to remove
+    # the last batch
+    batch_sampler = MegatronPretrainingSampler(
+        total_samples=len(dataset),
+        consumed_samples=0,
+        micro_batch_size=args.micro_batch_size,
+        data_parallel_rank=mpu.get_data_parallel_rank(),
+        data_parallel_size=mpu.get_data_parallel_world_size(),
+        drop_last=False)
+
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)


 def get_ict_batch(data_iterator):
```
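To make the semantics of `make_attention_mask` concrete, here is a small self-contained sketch using only NumPy. The token ids are invented, and the assumption implied by the `>= 1` test is that 0 is the padding id:

```python
import numpy as np

# Invented padded token-id sequences (0 assumed to be the pad id).
source_block = np.array([101, 2054, 2003, 102, 0, 0])   # 4 real tokens, 2 pads
target_block = np.array([101, 3231, 102, 0])            # 3 real tokens, 1 pad

# Same expression as make_attention_mask: a (source_length, target_length)
# grid that is 1 only where both positions hold real (non-pad) tokens.
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)

print(mask.shape)  # (6, 4)
print(mask)
# Rows for source pads and columns for target pads are all zeros.
```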
megatron/data/data_samplers.py (+14 −4)

```diff
@@ -57,7 +57,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
 class MegatronPretrainingSampler:

     def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size):
+                 data_parallel_rank, data_parallel_size, drop_last=True):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
@@ -65,6 +65,7 @@ class MegatronPretrainingSampler:
         self.data_parallel_rank = data_parallel_rank
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
+        self.drop_last = drop_last

         # Sanity checks.
         assert self.total_samples > 0, \
@@ -81,17 +82,26 @@ class MegatronPretrainingSampler:
     def __len__(self):
         return self.total_samples

+    def get_start_end_idx(self):
+        start_idx = self.data_parallel_rank * self.micro_batch_size
+        end_idx = start_idx + self.micro_batch_size
+        return start_idx, end_idx
+
     def __iter__(self):
         batch = []
-        # Last batch if not complete will be dropped.
+        # Last batch will be dropped if drop_last is not set False
         for idx in range(self.consumed_samples, self.total_samples):
             batch.append(idx)
             if len(batch) == self.micro_batch_times_data_parallel_size:
-                start_idx = self.data_parallel_rank * self.micro_batch_size
-                end_idx = start_idx + self.micro_batch_size
+                start_idx, end_idx = self.get_start_end_idx()
                 yield batch[start_idx:end_idx]
                 batch = []

+        # Check the last partial batch and see drop_last is set
+        if len(batch) > 0 and not self.drop_last:
+            start_idx, end_idx = self.get_start_end_idx()
+            yield batch[start_idx:end_idx]
+

 class MegatronPretrainingRandomSampler:
```
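A quick sketch of what the new `drop_last=False` path changes, assuming Megatron-LM is importable; the sampler itself is plain Python once the rank and world size are passed in, so no distributed setup is needed for a toy single-rank configuration:

```python
from megatron.data.data_samplers import MegatronPretrainingSampler

# 10 samples, micro batch of 4, a single data-parallel rank.
sampler = MegatronPretrainingSampler(total_samples=10,
                                     consumed_samples=0,
                                     micro_batch_size=4,
                                     data_parallel_rank=0,
                                     data_parallel_size=1,
                                     drop_last=False)

print(list(sampler))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
# With the default drop_last=True the trailing [8, 9] would be discarded,
# which is why get_one_epoch_dataloader passes drop_last=False so every
# document gets embedded during indexing.
```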