Loading megatron/data/realm_index.py +14 −7 Original line number Diff line number Diff line Loading @@ -7,6 +7,7 @@ import numpy as np import torch from megatron import get_args from megatron import mpu def detach(tensor): Loading Loading @@ -47,8 +48,10 @@ class BlockData(object): def load_from_file(self): """Populate members from instance saved to file""" if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print("\n> Unpickling BlockData", flush=True) state_dict = pickle.load(open(self.block_data_path, 'rb')) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print(">> Finished unpickling BlockData\n", flush=True) self.embed_data = state_dict['embed_data'] Loading Loading @@ -127,6 +130,7 @@ class FaissMIPSIndex(object): except ImportError: raise Exception("Error: Please install faiss to use FaissMIPSIndex") if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print("\n> Building index", flush=True) self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT) Loading @@ -138,10 +142,12 @@ class FaissMIPSIndex(object): config.useFloat16 = True self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True) else: # CPU index supports IDs so wrap with IDMap self.block_mips_index = faiss.IndexIDMap(self.block_mips_index) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print(">> Initialized index on CPU", flush=True) # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built Loading @@ -156,7 +162,7 @@ class FaissMIPSIndex(object): if self.block_data is not None: block_data_path = self.block_data.block_data_path del self.block_data self.block_data = BlockData.load_from_file(block_data_path) self.block_data = BlockData(block_data_path) self._set_block_index() Loading @@ -183,6 +189,7 @@ class FaissMIPSIndex(object): else: self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print(">>> Finished adding block data to index", flush=True) def search_mips_index(self, query_embeds, top_k, reconstruct=True): Loading megatron/indexer.py +5 −4 Original line number Diff line number Diff line Loading @@ -37,7 +37,8 @@ class IndexBuilder(object): model = get_model(lambda: general_ict_model_provider(only_block_model=True)) self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt) self.model.eval() self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), self.batch_size)) self.dataset = get_ict_dataset() self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size)) self.block_data = BlockData(load_from_path=False) def track_and_report_progress(self, batch_size): Loading @@ -58,7 +59,7 @@ class IndexBuilder(object): try: # batch also has query_tokens and query_pad_data _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader) except StopIteration: except (StopIteration, IndexError): break unwrapped_model = self.model Loading @@ -85,6 +86,6 @@ class IndexBuilder(object): # rank 0 process builds the final copy if self.is_main_builder: self.block_data.merge_shards_and_save() # make sure that every single piece of data was embedded assert len(self.block_data.embed_data) == len(self.dataset) self.block_data.clear() megatron/mpu/__init__.py +1 −0 Original line number Diff line number Diff line Loading @@ -21,6 +21,7 @@ from .data import broadcast_data from .grads import clip_grad_norm from .initialize import is_unitialized from .initialize import destroy_model_parallel from .initialize import get_data_parallel_group from .initialize import get_data_parallel_rank Loading megatron/mpu/initialize.py +5 −0 Original line number Diff line number Diff line Loading @@ -31,6 +31,11 @@ _MPU_WORLD_SIZE = None _MPU_RANK = None def is_unitialized(): """Useful for code segments that may be accessed with or without mpu initialization""" return _DATA_PARALLEL_GROUP is None def initialize_model_parallel(model_parallel_size_): """ Initialize model data parallel groups. Loading Loading
megatron/data/realm_index.py +14 −7 Original line number Diff line number Diff line Loading @@ -7,6 +7,7 @@ import numpy as np import torch from megatron import get_args from megatron import mpu def detach(tensor): Loading Loading @@ -47,8 +48,10 @@ class BlockData(object): def load_from_file(self): """Populate members from instance saved to file""" if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print("\n> Unpickling BlockData", flush=True) state_dict = pickle.load(open(self.block_data_path, 'rb')) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print(">> Finished unpickling BlockData\n", flush=True) self.embed_data = state_dict['embed_data'] Loading Loading @@ -127,6 +130,7 @@ class FaissMIPSIndex(object): except ImportError: raise Exception("Error: Please install faiss to use FaissMIPSIndex") if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print("\n> Building index", flush=True) self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT) Loading @@ -138,10 +142,12 @@ class FaissMIPSIndex(object): config.useFloat16 = True self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True) else: # CPU index supports IDs so wrap with IDMap self.block_mips_index = faiss.IndexIDMap(self.block_mips_index) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print(">> Initialized index on CPU", flush=True) # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built Loading @@ -156,7 +162,7 @@ class FaissMIPSIndex(object): if self.block_data is not None: block_data_path = self.block_data.block_data_path del self.block_data self.block_data = BlockData.load_from_file(block_data_path) self.block_data = BlockData(block_data_path) self._set_block_index() Loading @@ -183,6 +189,7 @@ class FaissMIPSIndex(object): else: self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print(">>> Finished adding block data to index", flush=True) def search_mips_index(self, query_embeds, top_k, reconstruct=True): Loading
megatron/indexer.py +5 −4 Original line number Diff line number Diff line Loading @@ -37,7 +37,8 @@ class IndexBuilder(object): model = get_model(lambda: general_ict_model_provider(only_block_model=True)) self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt) self.model.eval() self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), self.batch_size)) self.dataset = get_ict_dataset() self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size)) self.block_data = BlockData(load_from_path=False) def track_and_report_progress(self, batch_size): Loading @@ -58,7 +59,7 @@ class IndexBuilder(object): try: # batch also has query_tokens and query_pad_data _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader) except StopIteration: except (StopIteration, IndexError): break unwrapped_model = self.model Loading @@ -85,6 +86,6 @@ class IndexBuilder(object): # rank 0 process builds the final copy if self.is_main_builder: self.block_data.merge_shards_and_save() # make sure that every single piece of data was embedded assert len(self.block_data.embed_data) == len(self.dataset) self.block_data.clear()
megatron/mpu/__init__.py +1 −0 Original line number Diff line number Diff line Loading @@ -21,6 +21,7 @@ from .data import broadcast_data from .grads import clip_grad_norm from .initialize import is_unitialized from .initialize import destroy_model_parallel from .initialize import get_data_parallel_group from .initialize import get_data_parallel_rank Loading
megatron/mpu/initialize.py +5 −0 Original line number Diff line number Diff line Loading @@ -31,6 +31,11 @@ _MPU_WORLD_SIZE = None _MPU_RANK = None def is_unitialized(): """Useful for code segments that may be accessed with or without mpu initialization""" return _DATA_PARALLEL_GROUP is None def initialize_model_parallel(model_parallel_size_): """ Initialize model data parallel groups. Loading