Commit 8d7f508a authored by Neel Kant

Addressed Jared's comments

parent 03feecbc
+14 −21
@@ -272,15 +272,15 @@ Loosely, they are pretraining the retriever modules, then jointly training the l
### Inverse Cloze Task (ICT) Pretraining
1. Have a corpus in loose JSON format (one JSON object per line) with the intention of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document. 
Run `tools/preprocess_data.py` to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. For the original REALM system, we construct two datasets, one with the title of every document, and another with the body. 
Refer to the following script: 
<pre>
python preprocess_data.py \
    --input /path/to/corpus.json \
    --json-keys text title \
    --split-sentences \
    --tokenizer-type BertWordPieceLowerCase \
    --vocab-file /path/to/vocab.txt \
    --output-prefix corpus_indexed \
    --workers 5  # works well for 10 CPU cores. Scale up accordingly.
</pre>

@@ -288,13 +288,10 @@ python preprocess_data.py \
 The samples mapping is responsible for holding all of the required metadata needed to construct the sample from one or more indexed datasets. In REALM, the samples mapping contains the start and end sentence indices, as well as the document index (to find the correct title for a body) and a unique ID for every block. 
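As a concrete illustration of the metadata above, a single samples-mapping entry could be modeled as follows. The field and type names here are hypothetical, not the actual Megatron-LM data structures:
<pre>
from typing import NamedTuple

# Hypothetical sketch of one samples-mapping entry; field names are
# illustrative, not the actual Megatron-LM implementation.
class BlockSampleEntry(NamedTuple):
    start_idx: int  # index of the block's first sentence in the indexed dataset
    end_idx: int    # index one past the block's last sentence
    doc_idx: int    # document index, used to look up the matching title
    block_id: int   # unique ID for this block

entry = BlockSampleEntry(start_idx=104, end_idx=109, doc_idx=7, block_id=42)
num_sentences = entry.end_idx - entry.start_idx  # number of sentences in this block
</pre>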
3. Pretrain a BERT language model using `pretrain_bert.py`, with the sequence length equal to the block size in token ids. This model should be trained on the same indexed dataset that is used to supply the blocks for the information retrieval task.
In REALM, this is an uncased BERT-base model trained with the standard hyperparameters.
4. Use `pretrain_ict.py` to train an `ICTBertModel`, which uses two BERT-based encoders to encode queries and blocks for retrieval. 
The script below trains the ICT model from REALM. It references a pretrained BERT model (step 3) in the `--bert-load` argument. The batch size used in the paper is 4096, so this would need to be run with data parallel world size 32. 
<pre>
python pretrain_ict.py \
    --num-layers 12 \
    --num-attention-heads 12 \
    --hidden-size 768 \
@@ -304,13 +301,12 @@
    --ict-head-size 128 \
    --train-iters 100000 \
    --checkpoint-activations \
    --bert-load /path/to/pretrained_bert \
    --load checkpoints \
    --save checkpoints \
    --data-path /path/to/indexed_dataset \
    --titles-data-path /path/to/titles_indexed_dataset \
    --vocab-file /path/to/vocab.txt \
    --lr 0.0001 \
    --num-workers 2 \
    --lr-decay-style linear \
@@ -319,11 +315,8 @@
    --warmup .01 \
    --save-interval 3000 \
    --query-in-block-prob 0.1 \
    --fp16
</pre>
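For intuition, the retrieval objective that `ICTBertModel` optimizes can be sketched as an in-batch softmax over query-block similarity scores. This is a minimal illustration with random embeddings, not the actual Megatron-LM implementation; shapes follow `--ict-head-size 128`:
<pre>
import torch
import torch.nn.functional as F

# Minimal sketch of the ICT training signal: each query embedding should
# score highest against its own block embedding, with the other blocks in
# the batch serving as in-batch negatives.
batch_size, ict_head_size = 8, 128
query_embeds = torch.randn(batch_size, ict_head_size)  # from the query encoder
block_embeds = torch.randn(batch_size, ict_head_size)  # from the block encoder

scores = query_embeds @ block_embeds.t()  # (batch, batch) similarity matrix
labels = torch.arange(batch_size)         # the matching block sits on the diagonal
loss = F.cross_entropy(scores, labels)
</pre>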

<a id="evaluation-and-tasks"></a>
+26 −12
@@ -37,6 +37,7 @@ def parse_args(extra_args_provider=None, defaults={},
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_realm_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
@@ -139,8 +140,6 @@ def _add_network_size_args(parser):
                       '    grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--max-position-embeddings', type=int, default=None,
@@ -264,10 +263,6 @@ def _add_checkpointing_args(parser):
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Directory containing a model checkpoint.')
    group.add_argument('--no-load-optim', action='store_true',
                       help='Do not load optimizer when loading checkpoint.')
    group.add_argument('--no-load-rng', action='store_true',
@@ -347,8 +342,6 @@ def _add_data_args(parser):

    group.add_argument('--data-path', type=str, default=None,
                       help='Path to combined dataset to split.')
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
@@ -384,10 +377,6 @@ def _add_data_args(parser):
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')

    return parser

@@ -402,3 +391,28 @@ def _add_autoresume_args(parser):
                       'termination signal')

    return parser


def _add_realm_args(parser):
    group = parser.add_argument_group(title='realm')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint')
    group.add_argument('--bert-load', type=str, default=None,
                       help='Directory containing a BertModel checkpoint (needed to start ICT and REALM)')

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping query in block for ICT dataset')
    group.add_argument('--ict-one-sent', action='store_true',
                       help='Whether to use one sentence documents in ICT')

    return parser
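The `--query-in-block-prob` flag above corresponds to the inverse cloze sampling from the REALM paper: one sentence of a block becomes the query and, with some probability, is left in the block so the retriever cannot rely purely on string overlap. A hedged sketch of this sampling, with a hypothetical function name rather than the actual `ICTDataset` code:
<pre>
import random

def make_ict_example(block_sentences, query_in_block_prob=0.1, rng=random):
    """Hypothetical sketch of inverse cloze sampling: pick one sentence as
    the query and, with probability query_in_block_prob, keep it in the
    block; otherwise remove it."""
    query_idx = rng.randrange(len(block_sentences))
    query = block_sentences[query_idx]
    if rng.random() < query_in_block_prob:
        block = list(block_sentences)  # query stays in the block
    else:
        block = [s for i, s in enumerate(block_sentences) if i != query_idx]
    return query, block

# With probability 0, the query sentence is always removed from the block.
query, block = make_ict_example(["s0", "s1", "s2", "s3"], query_in_block_prob=0.0)
</pre>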
+2 −1
@@ -23,8 +23,9 @@ import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import mpu
from megatron import get_args
from megatron import print_rank_0


def check_checkpoint_args(checkpoint_args):
+2 −1
@@ -22,7 +22,8 @@ import numpy as np
import torch
from torch.utils.data import Dataset

from megatron import get_tokenizer, get_args
from megatron import print_rank_0
from megatron import mpu
from megatron.data.dataset_utils import get_a_and_b_segments
from megatron.data.dataset_utils import truncate_segments
+1 −1
@@ -399,7 +399,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,

    def build_dataset(index, name):
        from megatron.data.bert_dataset import BertDataset
        from megatron.data.ict_dataset import ICTDataset
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.