Commit 41a64613 authored by Rewon Child's avatar Rewon Child
Browse files

Merge main

parents 8676baca 83d26f03
Loading
Loading
Loading
Loading
+32 −0
Original line number Diff line number Diff line
#!/bin/bash

# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)

RANK=0
WORLD_SIZE=1

# Wikipedia data can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path to store embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>

python tools/create_doc_index.py \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --retriever-seq-length 256 \
    --max-position-embeddings 512 \
    --load ${CHECKPOINT_PATH} \
    --evidence-data-path ${EVIDENCE_DATA_DIR} \
    --embedding-path ${EMBEDDING_PATH} \
    --indexer-log-interval 1000 \
    --indexer-batch-size 128 \
    --vocab-file bert-vocab.txt \
    --num-workers 2 \
    --fp16
+44 −0
Original line number Diff line number Diff line
#! /bin/bash

# Runs the "217M" parameter biencoder model for ICT retriever

RANK=0
WORLD_SIZE=1

PRETRAINED_BERT_PATH=<Specify path of pretrained BERT model>
TEXT_DATA_PATH=<Specify path and file prefix of the text data>
TITLE_DATA_PATH=<Specify path and file prefix od the titles>
CHECKPOINT_PATH=<Specify path>


python pretrain_ict.py \
        --num-layers 12 \
        --hidden-size 768 \
        --num-attention-heads 12 \
        --tensor-model-parallel-size 1 \
        --micro-batch-size 32 \
        --seq-length 256 \
        --max-position-embeddings 512 \
        --train-iters 100000 \
        --vocab-file bert-vocab.txt \
        --tokenizer-type BertWordPieceLowerCase \
        --DDP-impl torch \
        --bert-load ${PRETRAINED_BERT_PATH} \
        --log-interval 100 \
        --eval-interval 1000 \
        --eval-iters 10 \
        --retriever-report-topk-accuracies 1 5 10 20 100 \
        --retriever-score-scaling \
        --load $CHECKPOINT_PATH \
        --save $CHECKPOINT_PATH \
        --data-path ${TEXT_DATA_PATH} \
        --titles-data-path ${TITLE_DATA_PATH} \
        --lr 0.0001 \
        --lr-decay-style linear \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
        --lr-warmup-fraction 0.01 \
        --save-interval 4000 \
        --exit-interval 8000 \
        --query-in-block-prob 0.1 \
        --fp16
+43 −8
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_realm_args(parser)
    parser = _add_biencoder_args(parser)
    parser = _add_vit_args(parser)
    parser = _add_logging_args(parser)

@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    assert args.world_size % model_parallel_size == 0, 'world size is not'\
        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
@@ -116,6 +116,15 @@ def parse_args(extra_args_provider=None, defaults={},
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
@@ -214,7 +223,7 @@ def parse_args(extra_args_provider=None, defaults={},
    custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0

    if args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion:
    if not (args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion):
        print('WARNING: constraints for invoking optimized'
            ' fused softmax kernel are not met. We default back to unfused'
            ' kernel invocations.')
@@ -559,6 +568,8 @@ def _add_distributed_args(parser):
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
@@ -566,6 +577,9 @@ def _add_distributed_args(parser):
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='Use scatter/gather to optimize communication of tensors in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
@@ -617,6 +631,12 @@ def _add_data_args(parser):
                       'This should be exclusive of --seq-length')
    group.add_argument('--decoder-seq-length', type=int, default=None,
                       help="Maximum decoder sequence length to process.")
    group.add_argument('--retriever-seq-length', type=int, default=256,
                       help='Maximum sequence length for the biencoder model '
                        ' for retriever')
    group.add_argument('--sample-rate', type=float, default=1.0,
                       help='sample rate for training data. Supposed to be 0 '
                            ' < sample_rate < 1')
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
@@ -657,13 +677,19 @@ def _add_autoresume_args(parser):
    return parser


def _add_realm_args(parser):
    group = parser.add_argument_group(title='realm')
def _add_biencoder_args(parser):
    group = parser.add_argument_group(title='biencoder')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and '
                        'REALM (paper default: 128)')
    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                       help='Size of projection head used in biencoder (paper'
                        ' default: 128)')
    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                        help='Whether to share the parameters of the query '
                        'and context models or not')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
@@ -680,16 +706,25 @@ def _add_realm_args(parser):
                       'ICT dataset')
    group.add_argument('--use-one-sent-docs', action='store_true',
                       help='Whether to use one sentence documents in ICT')
    group.add_argument('--evidence-data-path', type=str, default=None,
                       help='Path to Wikipedia Evidence frm DPR paper')

    # training
    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
                       help="Which top-k accuracies to report (e.g. '1 5 20')")
    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                        default=[], help="Which top-k accuracies to report "
                        "(e.g. '1 5 20')")
    group.add_argument('--retriever-score-scaling', action='store_true',
                       help='Whether to scale retriever scores by inverse '
                        'square root of hidden size')

    # faiss index
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether create the FaissMIPSIndex on GPU')
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from')
    group.add_argument('--embedding-path', type=str, default=None,
                       help='Where to save/load Open-Retrieval Embedding'
                        ' data to/from')

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
+63 −44
Original line number Diff line number Diff line
@@ -21,12 +21,12 @@ import sys
import numpy as np

import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP

from megatron import (get_args,
                      mpu,
                      print_rank_0,
                      update_num_microbatches)
                      update_num_microbatches,
                      utils)

_CHECKPOINT_VERSION = None

@@ -111,8 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    if isinstance(model, torchDDP):
        model = model.module
    model = utils.unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))
@@ -124,7 +123,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 3.0
        state_dict['iteration'] = iteration
        state_dict['model'] = model.state_dict_for_save_checkpoint()
        if len(model) == 1:
            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

        # Optimizer stuff.
        if not args.no_save_optim:
@@ -202,6 +206,33 @@ def _transpose_first_dim(t, num_splits, num_splits_first, model):

    return t

def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value matrix ordering if checkpoint
    version is smaller than 2.0
    """
    if checkpoint_version < 2.0:
        for name, param in model.named_parameters():
            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
            if name.endswith(('.key_value.weight', '.key_value.bias')):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
        print_rank_0(" succesfully fixed query-key-values ordering for"
                    " checkpoint version {}".format(checkpoint_version))

def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.
    strict (bool): whether to strictly enforce that the keys in
@@ -211,8 +242,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    args = get_args()
    load_dir = getattr(args, load_arg)

    if isinstance(model, torchDDP):
        model = model.module
    model = utils.unwrap_model(model)

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

@@ -297,30 +328,17 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    model.load_state_dict(state_dict['model'], strict=strict)
    if len(model) == 1:
        model[0].load_state_dict(state_dict['model'], strict=strict)
    else:
        for i in range(len(model)):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering
    if get_checkpoint_version() < 2.0:
    # Fix up query/key/value matrix ordering if needed
    checkpoint_version = get_checkpoint_version()
        for name, param in model.named_parameters():
            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
            if name.endswith(('.key_value.weight', '.key_value.bias')):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
@@ -365,41 +383,42 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    return iteration


def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
def load_biencoder_checkpoint(model, only_query_model=False,
        only_context_model=False, custom_load_path=None):
    """
    selectively load retrieval models for indexing/retrieving 
    from saved checkpoints
    """

    args = get_args()

    if isinstance(model, torchDDP):
        model = model.module
    model = utils.unwrap_model(model)

    load_path = args.load if from_realm_chkpt else args.ict_load
    load_path = custom_load_path if custom_load_path is not None else args.load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # assert iteration > 0
    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
        print(" loading ICT state dict from REALM", flush=True)
        ict_state_dict = ict_state_dict['retriever']['ict_model']
    ret_state_dict = state_dict['model']

    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')
        ret_state_dict.pop('context_model')
    if only_context_model:
        ret_state_dict.pop('query_model')

    model.load_state_dict(ict_state_dict)
    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
+211 −0
Original line number Diff line number Diff line
import os
import time

import numpy as np
import torch

from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, \
                                            pad_and_convert_to_numpy
from megatron.data.data_samplers import MegatronPretrainingSampler

def make_attention_mask(source_block, target_block):
    """
    Returns a 2-dimensional (2-D) attention mask
    :param source_block: 1-D array
    :param target_block: 1-D array
    """
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    mask = mask.astype(np.int64)
    # (source_length, target_length)
    return mask

def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    global_batch_size = micro_batch_size * world_size
    num_workers = args.num_workers

    # Use megatron's sampler with consumed samples set to 0 as
    # this is only for evaluation and don't intend to resume half way.
    # Also, set the drop last to false as don't intend to remove
    # the last batch
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=0,
        micro_batch_size=args.micro_batch_size,
        data_parallel_rank=mpu.get_data_parallel_rank(),
        data_parallel_size=mpu.get_data_parallel_world_size(),
        drop_last=False)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


def get_ict_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_mask',
            'context_tokens', 'context_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_mask = data_b['query_mask'] < 0.5
    context_tokens = data_b['context_tokens'].long()
    context_mask = data_b['context_mask'] < 0.5
    block_indices = data_b['block_data'].long()

    return query_tokens, query_mask,\
           context_tokens, context_mask, block_indices


def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result


class BlockSampleData(object):
    """A struct for fully describing a fixed-size block of data as used in REALM

    :param start_idx: for first sentence of the block
    :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
    :param doc_idx: the index of the document from which the block comes in the original indexed dataset
    :param block_idx: a unique integer identifier given to every block.
    """
    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.doc_idx = doc_idx
        self.block_idx = block_idx

    def as_array(self):
        return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)

    def as_tuple(self):
        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx


class BlockSamplesMapping(object):
    def __init__(self, mapping_array):
        # make sure that the array is compatible with BlockSampleData
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array

    def __len__(self):
        return self.mapping_array.shape[0]

    def __getitem__(self, idx):
        """Get the data associated with an indexed sample."""
        sample_data = BlockSampleData(*self.mapping_array[idx])
        return sample_data


def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
    """Get samples mapping for a dataset over fixed size blocks. This function also requires
    a dataset of the titles for the source documents since their lengths must be taken into account.

    :return: samples_mapping (BlockSamplesMapping)
    """

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if mpu.get_data_parallel_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))

        from megatron.data import helpers
        mapping_array = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose,
            use_one_sent_docs)


        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, mapping_array, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
            time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()

    mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    samples_mapping = BlockSamplesMapping(mapping_array)

    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        mapping_array.shape[0]))

    return samples_mapping
Loading