Commit 7577931b authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

Fixed issues with ICT pretraining

parent 8e44d619
Loading
Loading
Loading
Loading

examples/create_embeddings.sh

deleted100644 → 0
+0 −32
Original line number Diff line number Diff line
#!/bin/bash

# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)

RANK=0
WORLD_SIZE=1

# Wikipedia data can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path to store embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>

python tools/create_doc_index.py \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --retriever-seq-length 256 \
    --max-position-embeddings 512 \
    --load ${CHECKPOINT_PATH} \
    --evidence-data-path ${EVIDENCE_DATA_DIR} \
    --embedding-path ${EMBEDDING_PATH} \
    --indexer-log-interval 1000 \
    --indexer-batch-size 128 \
    --vocab-file bert-vocab.txt \
    --num-workers 2 \
    --fp16
+29 −20
Original line number Diff line number Diff line
@@ -14,6 +14,8 @@
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""

from functools import partial
import math

import torch
@@ -31,14 +33,15 @@ from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group


def pretrain_ict_model_provider():
def pretrain_ict_model_provider(pre_process=True, post_process=True):
    args = get_args()

    model = biencoder_model_provider(
                only_context_model=False,
                only_query_model=False,
                biencoder_shared_query_context_model=\
                args.biencoder_shared_query_context_model)
                args.biencoder_shared_query_context_model,
                pre_process=pre_process, post_process=post_process)

    return model

@@ -79,25 +82,9 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
        output = output_list[rank].contiguous()
        return output

def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
def loss_func(output_tensor):
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    query_tokens, query_mask, \
    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Query and Context Types
    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

    # Forward model.
    query_logits, context_logits = model(query_tokens, query_mask,
                                    query_types, context_tokens,
                                    context_mask, context_types)
    query_logits, context_logits = output_tensor

    micro_batch_size = query_logits.shape[0]
    # recall we assert that tensor_model_parallel_size == 1
@@ -139,6 +126,28 @@ def forward_step(data_iterator, model, input_tensor):
    return loss, stats_dict



def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    query_tokens, query_mask, \
    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Query and Context Types
    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

    # Forward model.
    output_tensor = model(query_tokens, query_mask, query_types, context_tokens,
                        context_mask, context_types)

    return output_tensor, partial(loss_func)

def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()

tools/create_doc_index.py

deleted100644 → 0
+0 −33
Original line number Diff line number Diff line
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))

from megatron import print_rank_0
from megatron.indexer import IndexBuilder
from megatron.initialize import initialize_megatron


def main():
    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
    - Include all args needed for initial model specification

    Other key args:
        --block-data-path: path to write to
        --ict-load or --realm-load: path to checkpoint with which to embed
        --data-path and --titles-data-path: paths for dataset
        --indexer-log-interval: reporting interval
        --indexer-batch-size: size specific for indexer jobs

    Check README.md for example script
    """

    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
    print_rank_0("Build and save indices: done!")

if __name__ == "__main__":
    main()