Commit 598d7ee2 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main_retriver_merge_dpr' into 'main'

DPR evaluation hangs and Readme

See merge request ADLR/megatron-lm!280
parents 83c4d95a 98113c69
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -29,7 +29,6 @@ python tasks/main.py \
    --retriever-seq-length 256 \
    --vocab-file  bert-vocab.txt\
    --qa-data-test ${QA_FILE} \
    --num-workers 2 \
    --faiss-use-gpu \
    --retriever-report-topk-accuracies 1 5 20 100 \
    --fp16 \
+1 −1
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
        --bert-load ${BERT_LOAD_PATH} \
        --save-interval 5000 \
        --log-interval 10 \
        --eval-interval 25000 \
        --eval-interval 20000 \
        --eval-iters 100 \
        --indexer-log-interval 1000 \
        --faiss-use-gpu \

tasks/orqa/README.md

0 → 100644
+32 −0
Original line number Diff line number Diff line
## End-to-End Training of Neural Retrievers for Open-Domain Question Answering

Below we present the steps to run unsupervised and supervised trainining and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).

### Unsupervised pretraining
1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT) task, which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body.

<pre>
python tools/preprocess_data.py \
    --input /path/to/corpus.json \
    --json-keys text title \
    --split-sentences \
    --tokenizer-type BertWordPieceLowerCase \
    --vocab-file /path/to/vocab.txt \
    --output-prefix corpus_indexed \
    --workers 10
</pre>

2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs a single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses a pretrained BERT model with a batch size of 4096 (hence the need for a data parallel world size of 32).

3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for natural question answering dataset.

### Supervised finetuning

1. Use the above pretrained ICT model to finetune using [Google's natural question answering dataset](https://ai.google.com/research/NaturalQuestions/). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example for how to do this. Our finetuning consists of score scaling, longer training (80 epochs), and hard negative examples.

2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model.

More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408).

The reader component will be available soon.
 
+32 −1
Original line number Diff line number Diff line
@@ -33,6 +33,28 @@ from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
from tasks.orqa.evaluate_utils import ORQAEvaluator

# input_ is a 2D tensor
def check_and_append_tensor_for_gather(group, rank, world_size, input_):

    # gather the size of the first dimension of the tensor from all ranks
    current_length = input_.size()[0]
    first_dim = torch.tensor([[current_length]], 
        device=torch.cuda.current_device())
    input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
    input_list[rank].copy_(first_dim)
    torch.distributed.all_gather(input_list, first_dim, group=group)
    all_input_list = torch.cat(input_list, dim=0).contiguous()
    max_length = torch.max(all_input_list)

    # if the size are different than the max, extend the tensor
    # accordingly
    if max_length > current_length:
        padding=tuple([0] * (input_.dim() * 2 - 1)) + \
            tuple([max_length - current_length])
        input_ = F.pad(input=input_, pad=padding)

    return input_

def orqa(Dataset):

    def cross_entropy_forward_step(batch, model):
@@ -47,6 +69,8 @@ def orqa(Dataset):
        except BaseException:
            batch_ = batch

        group, rank, world_size = get_group_world_size_rank()

        query_tokens, query_mask, query_types, query_pad_mask, \
        context_tokens, context_mask, context_types, context_pad_mask, \
        neg_context_tokens, neg_context_mask, neg_context_types, \
@@ -61,6 +85,14 @@ def orqa(Dataset):
            query_list.append(tokenizer.decode(query_tokens[i].tolist()))
            context_list.append(tokenizer.decode(context_tokens[i].tolist()))

        if neg_context_tokens is not None:
            neg_context_tokens = check_and_append_tensor_for_gather(group,
                rank, world_size, neg_context_tokens)
            neg_context_mask = check_and_append_tensor_for_gather(group,
                rank, world_size, neg_context_mask)
            neg_context_types = check_and_append_tensor_for_gather(group,
                rank, world_size, neg_context_types)

        if neg_context_tokens is not None:
            context_tokens = torch.cat([context_tokens, neg_context_tokens])
            context_mask = torch.cat([context_mask, neg_context_mask])
@@ -70,7 +102,6 @@ def orqa(Dataset):
        output_tensor = model(query_tokens, query_mask,
                                        query_types, context_tokens,
                                        context_mask, context_types)

        return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)