Commit b9bd1a11 authored by Neel Kant

Additional refactoring

parent ca0cdfaa
+11 −13
import os
import sys
import time

import torch
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import get_args, get_adlr_autoresume, print_rank_0
+from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.dataset_utils import get_indexed_dataset_
-from megatron.data.realm_dataset import ICTDataset
-from megatron.data.realm_dataset_utils import BlockSampleData
-from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
+from megatron.data.ict_dataset import ICTDataset
+from megatron.data.realm_index import detach, BlockData
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
-from pretrain_bert_ict import get_batch, general_ict_model_provider
+from pretrain_ict import get_batch, general_ict_model_provider


def pprint(*args):
@@ -30,17 +25,21 @@ class IndexBuilder(object):
        self.model = None
        self.dataloader = None
        self.block_data = None

+        # need to know whether we're using a REALM checkpoint (args.load) or an ICT checkpoint (args.ict_load)
+        assert not (args.load and args.ict_load)
+        self.using_realm_chkpt = args.ict_load is None

        self.load_attributes()
        self.is_main_builder = args.rank == 0
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """Load the necessary attributes: model, dataloader and empty BlockData"""
-        # TODO: handle from_realm_chkpt correctly
-        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=False)
+        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
-        self.block_data = BlockData()
+        self.block_data = BlockData(load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """Utility function for tracking progress"""
@@ -141,7 +140,6 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
-        short_seq_prob=0.0001,  # doesn't matter
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
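
Note: the new using_realm_chkpt logic above picks the checkpoint type from whichever flag is set. A minimal sketch of that selection rule, separate from the commit (the SimpleNamespace stand-in for Megatron's parsed args is hypothetical):

from types import SimpleNamespace

def resolve_checkpoint_type(args):
    # mirror the assert in IndexBuilder: at most one checkpoint path may be set
    assert not (args.load and args.ict_load)
    # no ict_load means we must be resuming from a REALM checkpoint
    return 'realm' if args.ict_load is None else 'ict'

print(resolve_checkpoint_type(SimpleNamespace(load='checkpoints/realm', ict_load=None)))  # realm
print(resolve_checkpoint_type(SimpleNamespace(load=None, ict_load='checkpoints/ict')))    # ict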
megatron/data/dataset_utils.py
+1 −1
Original line number Diff line number Diff line
@@ -417,7 +417,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
-                short_seq_prob=short_seq_prob,
                seed=seed
            )

@@ -434,6 +433,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
+                    short_seq_prob=short_seq_prob,
                    **kwargs
                )

megatron/data/realm_dataset.py

deleted 100644 → 0
+0 −115
Original line number Diff line number Diff line
import collections
import itertools
import random

import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import BlockSampleData, get_block_samples_mapping, join_str_list


class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
                 short_seq_prob, seed, use_titles=True, use_one_sent_docs=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles
        self.use_one_sent_docs = use_one_sent_docs

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
        sample_data = self.samples_mapping[idx]
        start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()

        if self.use_titles:
            title = self.title_dataset[int(doc_idx)]
            title_pad_offset = 3 + len(title)
        else:
            title = None
            title_pad_offset = 2
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1

        # randint() is inclusive for Python rng
        rand_sent_idx = self.rng.randint(0, len(block) - 1)

        # keep the query in the context query_in_block_prob fraction of the time.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because blocks are concluded when
        # the sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
            'query_pad_mask': query_pad_mask,
            'block_tokens': block_tokens,
            'block_pad_mask': block_pad_mask,
            'block_data': block_data,
        }

        return sample

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        title = self.title_dataset[int(doc_idx)]

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def get_null_block(self):
        """Get empty block and title - used in REALM pretraining"""
        block, title = [], []
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def concat_and_pad_tokens(self, tokens, title=None):
        """Concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = list(tokens)
        if title is None:
            tokens = [self.cls_id] + tokens + [self.sep_id]
        else:
            title = list(title)
            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
        assert len(tokens) <= self.max_seq_length

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad

        return np.array(tokens), np.array(pad_mask)
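
Note: the deleted ICTDataset moves to megatron/data/ict_dataset.py (see the new import in the first hunk). Its concat_and_pad_tokens explains the title_pad_offset arithmetic above: a titled block consumes 3 + len(title) positions ([CLS] title [SEP] ... [SEP]), a bare query only 2. A self-contained sketch, with hypothetical token ids and a small max length:

import numpy as np

CLS, SEP, PAD, MAX_SEQ_LENGTH = 101, 102, 0, 12  # hypothetical ids and length

def concat_and_pad_tokens(tokens, title=None):
    """Concat with special tokens and pad to MAX_SEQ_LENGTH, as in ICTDataset"""
    tokens = list(tokens)
    if title is None:
        tokens = [CLS] + tokens + [SEP]                        # 2 extra positions
    else:
        tokens = [CLS] + list(title) + [SEP] + tokens + [SEP]  # 3 + len(title) extra
    num_pad = MAX_SEQ_LENGTH - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * num_pad
    return np.array(tokens + [PAD] * num_pad), np.array(pad_mask)

tokens, mask = concat_and_pad_tokens([7, 8, 9], title=[5, 6])
# tokens -> [101 5 6 102 7 8 9 102 0 0 0 0], mask -> [1 1 1 1 1 1 1 1 0 0 0 0]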
megatron/data/realm_index.py
+52 −21
Original line number Diff line number Diff line
-from collections import defaultdict
import itertools
import os
import pickle
@@ -8,7 +7,7 @@ import faiss
import numpy as np
import torch

-from megatron import get_args, mpu
+from megatron import get_args


def detach(tensor):
@@ -17,7 +16,7 @@ def detach(tensor):

class BlockData(object):
    """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
-    def __init__(self, block_data_path=None, rank=None):
+    def __init__(self, block_data_path=None, load_from_path=True, rank=None):
        self.embed_data = dict()
        self.meta_data = dict()
        if block_data_path is None:
@@ -27,6 +26,9 @@ class BlockData(object):
        self.block_data_path = block_data_path
        self.rank = rank

+        if load_from_path:
+            self.load_from_file()

        block_data_name = os.path.splitext(self.block_data_path)[0]
        self.temp_dir_name = block_data_name + '_tmp'

@@ -43,18 +45,23 @@ class BlockData(object):
        """
        self.embed_data = dict()

-    @classmethod
-    def load_from_file(cls, fname):
+    def load_from_file(self):
        """Populate members from instance saved to file"""

        print("\n> Unpickling BlockData", flush=True)
-        state_dict = pickle.load(open(fname, 'rb'))
+        state_dict = pickle.load(open(self.block_data_path, 'rb'))
        print(">> Finished unpickling BlockData\n", flush=True)

-        new_index = cls()
-        new_index.embed_data = state_dict['embed_data']
-        new_index.meta_data = state_dict['meta_data']
-        return new_index
+        self.embed_data = state_dict['embed_data']
+        self.meta_data = state_dict['meta_data']

    def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
        """Add data for set of blocks
        :param block_indices: 1D array of unique int ids for the blocks
        :param block_embeds: 2D array of embeddings of the blocks
        :param block_metas: 2D array of metadata for the blocks.
            In the case of REALM this will be [start_idx, end_idx, doc_idx]
        """
        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")
@@ -63,6 +70,7 @@ class BlockData(object):
            self.meta_data[idx] = meta

    def save_shard(self):
        """Save the block data that was created this in this process"""
        if not os.path.isdir(self.temp_dir_name):
            os.makedirs(self.temp_dir_name, exist_ok=True)

@@ -104,9 +112,9 @@ class BlockData(object):

class FaissMIPSIndex(object):
    """Wrapper object for a BlockData which similarity search via FAISS under the hood"""
-    def __init__(self, index_type, embed_size, use_gpu=False):
-        self.index_type = index_type
+    def __init__(self, embed_size, block_data=None, use_gpu=False):
         self.embed_size = embed_size
+        self.block_data = block_data
        self.use_gpu = use_gpu
        self.id_map = dict()

@@ -114,10 +122,7 @@ class FaissMIPSIndex(object):
        self._set_block_index()

    def _set_block_index(self):
-        INDEX_TYPES = ['flat_ip']
-        if self.index_type not in INDEX_TYPES:
-            raise ValueError("Invalid index type specified")
-
+        """Create a Faiss Flat index with inner product as the metric to search against"""
        print("\n> Building index", flush=True)
        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)

@@ -129,29 +134,52 @@ class FaissMIPSIndex(object):
            config.useFloat16 = True

            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
            print(">>> Finished building index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)
            print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
        else:
            # CPU index supports IDs so wrap with IDMap
            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
            print(">> Finished building index\n", flush=True)
            print(">> Initialized index on CPU", flush=True)

+        # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
+        if self.block_data is not None:
+            self.add_block_embed_data(self.block_data)

+    def reset_index(self):
+        """Delete the existing index and create it anew"""
+        del self.block_mips_index
+
+        # reset the block data so that _set_block_index will reload it as well;
+        # load_from_file is now an instance method, so reload via the constructor
+        if self.block_data is not None:
+            block_data_path = self.block_data.block_data_path
+            del self.block_data
+            self.block_data = BlockData(block_data_path)
+
+        self._set_block_index()

    def add_block_embed_data(self, all_block_data):
        """Add the embedding of each block to the underlying FAISS index"""

        # this assumes the embed_data is a dict : {int: np.array<float>}
        block_indices, block_embeds = zip(*all_block_data.embed_data.items())

+        # the embeddings have to be entered as float32, even though the math is done internally in float16
+        block_embeds_arr = np.float32(np.array(block_embeds))
+        block_indices_arr = np.array(block_indices)

        # faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
        if self.use_gpu:
            for i, idx in enumerate(block_indices):
                self.id_map[i] = idx

        # we no longer need the embedding data since it's in the index now
        all_block_data.clear()

        if self.use_gpu:
-            self.block_mips_index.add(np.float32(np.array(block_embeds)))
+            self.block_mips_index.add(block_embeds_arr)
        else:
-            self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))
+            self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)

        print(">>> Finished adding block data to index", flush=True)

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """Get the top-k blocks by the index distance metric.
@@ -160,12 +188,15 @@ class FaissMIPSIndex(object):
                            if False: return [num_queries x k] array of distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))

        with torch.no_grad():

            if reconstruct:
                # get the vectors themselves
                top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
                return top_k_block_embeds

            else:
                # get distances and indices of closest vectors
                distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
                if self.use_gpu:
                    fresh_indices = np.zeros(block_indices.shape)
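
Note: taken together, the BlockData/FaissMIPSIndex changes make construction declarative: a BlockData now loads itself, and an index given a block_data ingests it as soon as the FAISS structure exists. A usage sketch, not from the repo (the pickle path, embedding size, and random queries are made up):

import torch
from megatron.data.realm_index import BlockData, FaissMIPSIndex

# load_from_path defaults to True, so the pickle is read during __init__
block_data = BlockData('embeds/block_data.pkl')

# passing block_data makes _set_block_index() call add_block_embed_data() automatically
mips_index = FaissMIPSIndex(embed_size=128, block_data=block_data)

query_embeds = torch.randn(4, 128)  # [num_queries x embed_size]
distances, block_indices = mips_index.search_mips_index(query_embeds, top_k=5, reconstruct=False)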
megatron/tokenizer/tokenizer.py
+5 −5
Original line number Diff line number Diff line
@@ -31,10 +31,10 @@ def build_tokenizer(args):
    # Select and instantiate the tokenizer.
    assert args.vocab_file is not None
    if args.tokenizer_type == 'BertWordPieceLowerCase':
-        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
+        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                            lower_case=True)
    elif args.tokenizer_type == 'BertWordPieceCase':
-        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
+        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                            lower_case=False)
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.merge_file is not None
@@ -124,7 +124,7 @@ class AbstractTokenizer(ABC):
                                  'tokenizer'.format(self.name))


-class BertWordPieceTokenizer(AbstractTokenizer):
+class _BertWordPieceTokenizer(AbstractTokenizer):
    """Original BERT wordpiece tokenizer."""

    def __init__(self, vocab_file, lower_case=True):
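
Note: with the leading underscore, _BertWordPieceTokenizer is marked module-private, and build_tokenizer(args) is the intended entry point. A sketch of that call, not from the commit; the fields beyond tokenizer_type and vocab_file are assumptions about what this era's build_tokenizer reads for logging and vocab-size padding:

from types import SimpleNamespace
from megatron.tokenizer.tokenizer import build_tokenizer

args = SimpleNamespace(tokenizer_type='BertWordPieceLowerCase',
                       vocab_file='bert-vocab.txt',       # hypothetical path
                       rank=0,                            # assumed: used for logging
                       make_vocab_size_divisible_by=128,  # assumed: used for padding
                       model_parallel_size=1)             # assumed: used for padding
tokenizer = build_tokenizer(args)
token_ids = tokenizer.tokenize('megatron builds big models')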