indexer.py  +11 −13

 import os
 import sys
 import time

 import torch
 import torch.distributed as dist
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import get_args, get_adlr_autoresume, print_rank_0
+from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.dataset_utils import get_indexed_dataset_
-from megatron.data.realm_dataset import ICTDataset
-from megatron.data.realm_dataset_utils import BlockSampleData
-from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
+from megatron.data.ict_dataset import ICTDataset
+from megatron.data.realm_index import detach, BlockData
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.training import get_model
-from pretrain_bert_ict import get_batch, general_ict_model_provider
+from pretrain_ict import get_batch, general_ict_model_provider


 def pprint(*args):

@@ -30,17 +25,21 @@ class IndexBuilder(object):
         self.model = None
         self.dataloader = None
         self.block_data = None
+
+        # need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
+        assert not (args.load and args.ict_load)
+        self.using_realm_chkpt = args.ict_load is None
+
         self.load_attributes()
         self.is_main_builder = args.rank == 0
         self.iteration = self.total_processed = 0

     def load_attributes(self):
         """Load the necessary attributes: model, dataloader and empty BlockData"""
-        # TODO: handle from_realm_chkpt correctly
-        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=False)
+        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
         self.model.eval()
         self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
-        self.block_data = BlockData()
+        self.block_data = BlockData(load_from_path=False)

     def track_and_report_progress(self, batch_size):
         """Utility function for tracking progress"""

@@ -141,7 +140,6 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
         num_epochs=1,
         max_num_samples=None,
         max_seq_length=args.seq_length,
-        short_seq_prob=0.0001, # doesn't matter
         seed=1,
         query_in_block_prob=query_in_block_prob,
         use_titles=use_titles,
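With this change the block model can be restored from either a REALM checkpoint (args.load) or an ICT checkpoint (args.ict_load). A minimal sketch of the selection logic added above, with hypothetical argument values standing in for Megatron's parsed args:

    # Exactly one of the two checkpoint arguments is expected to be set.
    class FakeArgs:                          # stand-in for Megatron's parsed args
        load = '/path/to/realm_checkpoint'   # hypothetical REALM checkpoint directory
        ict_load = None                      # no ICT checkpoint supplied

    args = FakeArgs()
    assert not (args.load and args.ict_load)    # both at once would be ambiguous
    using_realm_chkpt = args.ict_load is None   # True here, so load_ict_checkpoint
                                                # runs with from_realm_chkpt=True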
megatron/data/dataset_utils.py  +1 −1

@@ -417,7 +417,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             num_epochs=None,
             max_num_samples=train_valid_test_num_samples[index],
             max_seq_length=max_seq_length,
-            short_seq_prob=short_seq_prob,
             seed=seed
         )

@@ -434,6 +433,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             dataset = BertDataset(
                 indexed_dataset=indexed_dataset,
                 masked_lm_prob=masked_lm_prob,
+                short_seq_prob=short_seq_prob,
                 **kwargs
             )

megatron/data/realm_dataset.py  deleted (100644 → 0)  +0 −115

import collections
import itertools
import random

import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import BlockSampleData, get_block_samples_mapping, join_str_list


class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix, num_epochs,
                 max_num_samples, max_seq_length, query_in_block_prob, short_seq_prob,
                 seed, use_titles=True, use_one_sent_docs=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles
        self.use_one_sent_docs = use_one_sent_docs

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
        sample_data = self.samples_mapping[idx]
        start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()

        if self.use_titles:
            title = self.title_dataset[int(doc_idx)]
            title_pad_offset = 3 + len(title)
        else:
            title = None
            title_pad_offset = 2
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1

        # randint() is inclusive for Python rng
        rand_sent_idx = self.rng.randint(0, len(block) - 1)

        # keep the query in the context query_in_block_prob fraction of the time.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because blocks are concluded when
        # the sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
            'query_pad_mask': query_pad_mask,
            'block_tokens': block_tokens,
            'block_pad_mask': block_pad_mask,
            'block_data': block_data,
        }
        return sample

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        title = self.title_dataset[int(doc_idx)]

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def get_null_block(self):
        """Get empty block and title - used in REALM pretraining"""
        block, title = [], []
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def concat_and_pad_tokens(self, tokens, title=None):
        """Concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = list(tokens)
        if title is None:
            tokens = [self.cls_id] + tokens + [self.sep_id]
        else:
            title = list(title)
            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
        assert len(tokens) <= self.max_seq_length

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad
        return np.array(tokens), np.array(pad_mask)
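Since the ICT sample construction in the deleted file ends with concat_and_pad_tokens, a self-contained toy version of that padding scheme may help; the ids below are placeholders rather than real BERT vocabulary ids:

    import numpy as np

    CLS, SEP, PAD = 101, 102, 0     # stand-ins for tokenizer.cls / sep / pad
    MAX_SEQ_LENGTH = 8

    def concat_and_pad(tokens, title=None, max_seq_length=MAX_SEQ_LENGTH):
        # Mirrors ICTDataset.concat_and_pad_tokens with hard-coded special ids.
        tokens = list(tokens)
        if title is None:
            tokens = [CLS] + tokens + [SEP]
        else:
            tokens = [CLS] + list(title) + [SEP] + tokens + [SEP]
        assert len(tokens) <= max_seq_length
        num_pad = max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [PAD] * num_pad
        return np.array(tokens), np.array(pad_mask)

    tokens, mask = concat_and_pad([5, 6, 7], title=[9])
    # tokens -> [101   9 102   5   6   7 102   0]
    # mask   -> [  1   1   1   1   1   1   1   0]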
megatron/data/realm_index.py  +52 −21

 from collections import defaultdict
 import itertools
 import os
 import pickle

@@ -8,7 +7,7 @@
 import faiss
 import numpy as np
 import torch

-from megatron import get_args, mpu
+from megatron import get_args


 def detach(tensor):

@@ -17,7 +16,7 @@ class BlockData(object):
     """Serializable data structure for holding data for blocks --
     embeddings and necessary metadata for REALM"""
-    def __init__(self, block_data_path=None, rank=None):
+    def __init__(self, block_data_path=None, load_from_path=True, rank=None):
         self.embed_data = dict()
         self.meta_data = dict()
         if block_data_path is None:

@@ -27,6 +26,9 @@ class BlockData(object):
        self.block_data_path = block_data_path
        self.rank = rank

+        if load_from_path:
+            self.load_from_file()
+
        block_data_name = os.path.splitext(self.block_data_path)[0]
        self.temp_dir_name = block_data_name + '_tmp'

@@ -43,18 +45,23 @@ class BlockData(object):
        """
        self.embed_data = dict()

-    @classmethod
-    def load_from_file(cls, fname):
+    def load_from_file(self):
        """Populate members from instance saved to file"""
        print("\n> Unpickling BlockData", flush=True)
-        state_dict = pickle.load(open(fname, 'rb'))
+        state_dict = pickle.load(open(self.block_data_path, 'rb'))
        print(">> Finished unpickling BlockData\n", flush=True)

-        new_index = cls()
-        new_index.embed_data = state_dict['embed_data']
-        new_index.meta_data = state_dict['meta_data']
-        return new_index
+        self.embed_data = state_dict['embed_data']
+        self.meta_data = state_dict['meta_data']

     def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
+        """Add data for set of blocks
+
+        :param block_indices: 1D array of unique int ids for the blocks
+        :param block_embeds: 2D array of embeddings of the blocks
+        :param block_metas: 2D array of metadata for the blocks.
+            In the case of REALM this will be [start_idx, end_idx, doc_idx]
+        """
        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

@@ -63,6 +70,7 @@ class BlockData(object):
            self.meta_data[idx] = meta

     def save_shard(self):
+        """Save the block data that was created in this process"""
        if not os.path.isdir(self.temp_dir_name):
            os.makedirs(self.temp_dir_name, exist_ok=True)

@@ -104,9 +112,9 @@ class BlockData(object):
 class FaissMIPSIndex(object):
     """Wrapper object for a BlockData which does similarity search via FAISS under the hood"""
-    def __init__(self, index_type, embed_size, use_gpu=False):
-        self.index_type = index_type
+    def __init__(self, embed_size, block_data=None, use_gpu=False):
         self.embed_size = embed_size
+        self.block_data = block_data
         self.use_gpu = use_gpu
         self.id_map = dict()

@@ -114,10 +122,7 @@ class FaissMIPSIndex(object):
         self._set_block_index()

     def _set_block_index(self):
-        INDEX_TYPES = ['flat_ip']
-        if self.index_type not in INDEX_TYPES:
-            raise ValueError("Invalid index type specified")
-
+        """Create a Faiss Flat index with inner product as the metric to search against"""
         print("\n> Building index", flush=True)
         self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)

@@ -129,29 +134,52 @@ class FaissMIPSIndex(object):
            config.useFloat16 = True

            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
-            print(">>> Finished building index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)
+            print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
        else:
            # CPU index supports IDs so wrap with IDMap
            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
-            print(">> Finished building index\n", flush=True)
+            print(">> Initialized index on CPU", flush=True)
+
+        # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
+        if self.block_data is not None:
+            self.add_block_embed_data(self.block_data)
+
+    def reset_index(self):
+        """Delete existing index and create anew"""
+        del self.block_mips_index
+
+        # reset the block data so that _set_block_index will reload it as well
+        if self.block_data is not None:
+            block_data_path = self.block_data.block_data_path
+            del self.block_data
+            self.block_data = BlockData.load_from_file(block_data_path)
+
+        self._set_block_index()

     def add_block_embed_data(self, all_block_data):
        """Add the embedding of each block to the underlying FAISS index"""
        # this assumes the embed_data is a dict : {int: np.array<float>}
        block_indices, block_embeds = zip(*all_block_data.embed_data.items())
+
+        # the embeddings have to be entered in as float32 even though the math internally is done with float16.
+        block_embeds_arr = np.float32(np.array(block_embeds))
+        block_indices_arr = np.array(block_indices)

        # faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
        if self.use_gpu:
            for i, idx in enumerate(block_indices):
                self.id_map[i] = idx

+        # we no longer need the embedding data since it's in the index now
+        all_block_data.clear()
+
        if self.use_gpu:
-            self.block_mips_index.add(np.float32(np.array(block_embeds)))
+            self.block_mips_index.add(block_embeds_arr)
        else:
-            self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))
+            self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)

        print(">>> Finished adding block data to index", flush=True)

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """Get the top-k blocks by the index distance metric.

@@ -160,12 +188,15 @@ class FaissMIPSIndex(object):
        if False: return [num_queries x k] array of distances, and another for indices
        """
+        query_embeds = np.float32(detach(query_embeds))
+
        with torch.no_grad():
            if reconstruct:
+                # get the vectors themselves
                top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
                return top_k_block_embeds
            else:
+                # get distances and indices of closest vectors
                distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
                if self.use_gpu:
                    fresh_indices = np.zeros(block_indices.shape)
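For readers unfamiliar with the FAISS calls used in _set_block_index and add_block_embed_data, here is a self-contained sketch of the CPU path (a flat inner-product index wrapped in an IDMap), with toy data in place of real block embeddings; the sizes and ids are arbitrary:

    import faiss
    import numpy as np

    embed_size, num_blocks, top_k = 8, 100, 3

    # Flat index with inner product as the metric, then wrapped with IDMap
    # so search results come back as caller-supplied block ids (the CPU branch above).
    index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
    index = faiss.IndexIDMap(index)

    block_embeds = np.random.rand(num_blocks, embed_size).astype(np.float32)
    block_ids = np.arange(1000, 1000 + num_blocks, dtype=np.int64)   # arbitrary unique ids
    index.add_with_ids(block_embeds, block_ids)

    queries = np.random.rand(2, embed_size).astype(np.float32)
    distances, ids = index.search(queries, top_k)   # each is a [2 x top_k] array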
megatron/tokenizer/tokenizer.py  +5 −5

@@ -31,10 +31,10 @@ def build_tokenizer(args):
     # Select and instantiate the tokenizer.
     assert args.vocab_file is not None
     if args.tokenizer_type == 'BertWordPieceLowerCase':
-        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
+        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                            lower_case=True)
     elif args.tokenizer_type == 'BertWordPieceCase':
-        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
+        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                            lower_case=False)
     elif args.tokenizer_type == 'GPT2BPETokenizer':
         assert args.merge_file is not None

@@ -124,7 +124,7 @@ class AbstractTokenizer(ABC):
                                   'tokenizer'.format(self.name))


-class BertWordPieceTokenizer(AbstractTokenizer):
+class _BertWordPieceTokenizer(AbstractTokenizer):
     """Original BERT wordpiece tokenizer."""

     def __init__(self, vocab_file, lower_case=True):
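The leading underscore marks the wordpiece tokenizer as an implementation detail of this module; build_tokenizer (or megatron.get_tokenizer once Megatron is initialized) remains the intended entry point. For illustration only, direct construction mirrors what build_tokenizer does for the 'BertWordPieceLowerCase' case; the vocab path below is a placeholder:

    from megatron.tokenizer.tokenizer import _BertWordPieceTokenizer

    # Same call build_tokenizer makes for tokenizer_type == 'BertWordPieceLowerCase'.
    tokenizer = _BertWordPieceTokenizer(vocab_file='bert-vocab.txt', lower_case=True)
    token_ids = tokenizer.tokenize('hello world')   # wordpiece ids for the input text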