Commit 83c4d95a authored by Jared Casper

Merge branch 'main_retriver_merge_dpr' into 'main'

Finetuning retriever (ICT+DPR)

See merge request ADLR/megatron-lm!277
parents 01fc0833 fda81a21

examples/create_embeddings.sh

deleted 100644 → 0
+0 −32
#!/bin/bash

# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)

RANK=0
WORLD_SIZE=1

# Wikipedia data can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path to store embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>

python tools/create_doc_index.py \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --retriever-seq-length 256 \
    --max-position-embeddings 512 \
    --load ${CHECKPOINT_PATH} \
    --evidence-data-path ${EVIDENCE_DATA_DIR} \
    --embedding-path ${EMBEDDING_PATH} \
    --indexer-log-interval 1000 \
    --indexer-batch-size 128 \
    --vocab-file bert-vocab.txt \
    --num-workers 2 \
    --fp16
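
The script above writes one vector per evidence row to EMBEDDING_PATH. A minimal sketch for sanity-checking the result, assuming the store is a pickled mapping from row id to embedding (the actual on-disk layout is defined by megatron.data.realm_index.OpenRetreivalDataStore and may differ):

# Hypothetical inspection of the embedding file produced above; the pickled
# {row_id: vector} layout is an assumption, not the documented format.
import pickle
import sys

import numpy as np

def inspect_embeddings(embedding_path):
    with open(embedding_path, "rb") as f:
        id_to_embedding = pickle.load(f)
    sample = np.asarray(next(iter(id_to_embedding.values())))
    print(f"{len(id_to_embedding)} rows, embedding dim {sample.shape[-1]}")

if __name__ == "__main__":
    inspect_embeddings(sys.argv[1])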
+8 −5
#!/bin/bash

# Evaluate natural question test data given Wikipedia embeddings and pretrained
# ICT model
# ICT model or a finetuned model for Natural Question task

# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py

EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path of the embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>
CHECKPOINT_PATH=<Specify path of pretrained ICT model or finetuned model>

QA_FILE=<Path of the natural question test dataset>
QA_FILE=<Path of the natural question dev or test dataset>

python tasks/main.py \
    --task ICT-ZEROSHOT-NQ \
    --task RETRIEVER-EVAL \
    --tokenizer-type BertWordPieceLowerCase \
    --num-layers 12 \
    --hidden-size 768 \
@@ -32,5 +32,8 @@ python tasks/main.py \
    --num-workers 2 \
    --faiss-use-gpu \
    --retriever-report-topk-accuracies 1 5 20 100 \
    --fp16
    --fp16 \
    --indexer-log-interval 1000 \
    --indexer-batch-size 128

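For reference, --retriever-report-topk-accuracies reports, for each k, the fraction of questions whose gold evidence row appears among the top-k retrieved rows. An illustrative sketch of that metric (names and data layout are assumptions, not the task's implementation):

# Illustrative top-k retrieval accuracy; ranked_ids[i] lists the evidence row
# ids retrieved for question i, best first, and gold_ids[i] is its gold row.
def topk_accuracies(ranked_ids, gold_ids, ks=(1, 5, 20, 100)):
    results = {}
    for k in ks:
        hits = sum(1 for ranked, gold in zip(ranked_ids, gold_ids)
                   if gold in ranked[:k])
        results[k] = hits / len(gold_ids)
    return results

# Two questions whose gold rows are ranked 1st and 7th -> {1: 0.5, 5: 0.5}
print(topk_accuracies([[3, 9, 4], [8, 1, 2, 0, 5, 6, 7]], [3, 7], ks=(1, 5)))
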
+56 −0
#!/bin/bash

# Finetune a BERT or pretrained ICT model using the Google Natural Questions data
# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py

WORLD_SIZE=8

DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
                  --nnodes 1 \
                  --node_rank 0 \
                  --master_addr localhost \
                  --master_port 6000"

CHECKPOINT_PATH=<Specify path for the finetuned retriever model>

# Load either of the below
BERT_LOAD_PATH=<Path of BERT pretrained model>
PRETRAINED_CHECKPOINT=<Path of Pretrained ICT model>

python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
        --task RET-FINETUNE-NQ \
        --train-with-neg \
        --train-hard-neg 1 \
        --pretrained-checkpoint ${PRETRAINED_CHECKPOINT} \
        --num-layers 12 \
        --hidden-size 768 \
        --num-attention-heads 12 \
        --tensor-model-parallel-size 1 \
        --tokenizer-type BertWordPieceLowerCase \
        --train-data nq-train.json \
        --valid-data nq-dev.json \
        --save ${CHECKPOINT_PATH} \
        --load ${CHECKPOINT_PATH} \
        --vocab-file bert-vocab.txt \
        --bert-load ${BERT_LOAD_PATH} \
        --save-interval 5000 \
        --log-interval 10 \
        --eval-interval 25000 \
        --eval-iters 100 \
        --indexer-log-interval 1000 \
        --faiss-use-gpu \
        --DDP-impl torch \
        --fp16 \
        --retriever-report-topk-accuracies 1 5 10 20 100 \
        --seq-length 512 \
        --retriever-seq-length 256 \
        --max-position-embeddings 512 \
        --retriever-score-scaling \
        --epochs 80 \
        --micro-batch-size 8 \
        --eval-micro-batch-size 16 \
        --indexer-batch-size 128 \
        --lr 2e-5 \
        --lr-warmup-fraction 0.01 \
        --weight-decay 1e-1
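
The --train-with-neg, --train-hard-neg and --retriever-score-scaling flags correspond to a DPR-style objective: softmax cross-entropy over query/context dot products, where the other gold contexts in the batch and the sampled hard negatives all serve as negatives, and scores are optionally scaled (commonly by 1/sqrt(hidden_size)). A minimal PyTorch sketch of that objective, with shapes and names as assumptions rather than the repository's exact code:

import math
import torch
import torch.nn.functional as F

def biencoder_nq_loss(query_emb, context_emb, hidden_size, score_scaling=True):
    # query_emb:   [batch, dim] question embeddings
    # context_emb: [batch * (1 + num_hard_negatives), dim] embeddings of the
    #              gold contexts (stored first) followed by the hard negatives
    scores = torch.matmul(query_emb, context_emb.t())  # [batch, num_contexts]
    if score_scaling:
        scores = scores / math.sqrt(hidden_size)
    # Question i's gold context sits in column i; every other column acts as
    # an in-batch or hard negative.
    targets = torch.arange(query_emb.size(0), device=scores.device)
    return F.cross_entropy(scores, targets)
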
+7 −5
import sys
import time
import torch
import torch.distributed as dist

from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model


@@ -29,7 +30,6 @@ class IndexBuilder(object):
        # need to know whether we're using a REALM checkpoint (args.load)
        # or ICT checkpoint
        assert not (args.load and args.ict_load)
        #self.using_realm_chkpt = args.ict_load is None

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size
@@ -47,8 +47,8 @@ class IndexBuilder(object):
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        model = get_model(lambda: biencoder_model_provider(only_context_model \
            = only_context_model, biencoder_shared_query_context_model = \
        model = get_model(get_model_provider(only_context_model=\
            only_context_model, biencoder_shared_query_context_model=\
            self.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
@@ -85,6 +85,7 @@ class IndexBuilder(object):
        """
        assert len(self.model) == 1
        unwrapped_model = self.model[0]

        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

@@ -103,6 +104,7 @@ class IndexBuilder(object):
            context_logits = unwrapped_model.embed_text(
                unwrapped_model.context_model, context_tokens, context_mask,
                context_types)

            context_logits = detach(context_logits)
            row_id = detach(row_id)

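Taken together, the changes above keep IndexBuilder's overall flow: unwrap the DDP/FP16-wrapped model until embed_text is reachable, embed each evidence batch with the context encoder, and hand the detached vectors to the data store. A simplified sketch of that loop (the dataloader tuple layout and the store's add_block_data call are assumptions drawn from the surrounding code):

import torch

def build_evidence_index(model, dataloader, data_store):
    # Unwrap DDP / FP16 wrappers until the biencoder exposing embed_text is
    # reached, mirroring the while-loop in IndexBuilder above.
    unwrapped = model
    while not hasattr(unwrapped, 'embed_text'):
        unwrapped = unwrapped.module

    unwrapped.eval()
    with torch.no_grad():
        for row_id, tokens, mask, types in dataloader:
            logits = unwrapped.embed_text(
                unwrapped.context_model, tokens, mask, types)
            # Move everything to host memory before handing it to the store,
            # as realm_index.detach does.
            data_store.add_block_data(row_id.cpu().numpy(),
                                      logits.float().cpu().numpy())
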
+44 −9
@@ -15,11 +15,30 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule

def get_model_provider(only_query_model=False, only_context_model=False,
        biencoder_shared_query_context_model=False):

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""

        print_rank_0('building Biencoder model ...')
        model = biencoder_model_provider(only_query_model=only_query_model,
                only_context_model = only_context_model,
                biencoder_shared_query_context_model = \
                biencoder_shared_query_context_model,
                pre_process=pre_process, post_process=post_process)

        return model

    return model_provider


def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             biencoder_shared_query_context_model=False):
                             biencoder_shared_query_context_model=False,
                             pre_process=True,
                             post_process=True):
    """Build the model."""
    args = get_args()

    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
@@ -35,7 +54,9 @@ def biencoder_model_provider(only_query_model=False,
        only_query_model=only_query_model,
        only_context_model=only_context_model,
        biencoder_shared_query_context_model=\
            biencoder_shared_query_context_model)
        biencoder_shared_query_context_model,
        pre_process=pre_process,
        post_process=post_process)

    return model

@@ -48,13 +69,17 @@ class BiEncoderModel(MegatronModule):
                 parallel_output=True,
                 only_query_model=False,
                 only_context_model=False,
                 biencoder_shared_query_context_model=False):
                 biencoder_shared_query_context_model=False,
                 pre_process=True,
                 post_process=True):
        super(BiEncoderModel, self).__init__()
        args = get_args()

        bert_kwargs = dict(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output)
            parallel_output=parallel_output,
            pre_process=pre_process,
            post_process=post_process)

        self.biencoder_shared_query_context_model = \
            biencoder_shared_query_context_model
@@ -78,6 +103,13 @@ class BiEncoderModel(MegatronModule):
                self.context_model = PretrainedBertModel(**bert_kwargs)
                self._context_key = 'context_model'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        # This is just a placeholder; it will be needed when model
        # parallelism is used.
        # self.language_model.set_input_tensor(input_tensor)
        return

    def forward(self, query_tokens, query_attention_mask, query_types,
                context_tokens, context_attention_mask, context_types):
        """Run a forward pass for each of the models and
@@ -217,7 +249,7 @@ class PretrainedBertModel(MegatronModule):
    learned information retrieval."""

    def __init__(self, num_tokentypes=2,
            parallel_output=True):
            parallel_output=True, pre_process=True, post_process=True):
        super(PretrainedBertModel, self).__init__()

        args = get_args()
@@ -225,6 +257,8 @@ class PretrainedBertModel(MegatronModule):
        self.pad_id = tokenizer.pad
        self.biencoder_projection_dim = args.biencoder_projection_dim
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)
@@ -234,7 +268,9 @@ class PretrainedBertModel(MegatronModule):
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method)
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        if args.biencoder_projection_dim > 0:
            self.projection_enc = get_linear_layer(args.hidden_size,
@@ -247,7 +283,6 @@ class PretrainedBertModel(MegatronModule):
        #extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)


        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
@@ -285,7 +320,7 @@ class PretrainedBertModel(MegatronModule):

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        print_rank_0("loading BERT weights")
        print_rank_0("loading pretrained weights")
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
