Commit 2e0b3fca authored by Steven Steinke's avatar Steven Steinke
Browse files

Fixed minor inconsistencies in scripts, added distributed comment

parent 050c24d5
Loading
Loading
Loading
Loading
+20 −14
Original line number Diff line number Diff line
@@ -284,7 +284,7 @@ WORLD_SIZE=$MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \

</pre>

Several downstream tasks are described for both GPT-2 and BERT models below.
Several downstream tasks are described for both GPT-2 and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.

<a id="gpt-2-text-generation"></a>
## GPT-2 Text Generation
@@ -323,11 +323,11 @@ We include example scripts for GPT-2 evaluation on WikiText perplexity evaluatio
### WikiText Perplexity Evaluation
For even comparison with prior works, we evaluate perplexity on the word-level [WikiText-103 test dataset](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip), and appropriately compute perplexity given the change in tokens when using our subword tokenizer.

We use the following command to run WikiText-103 evaluation on a 345M parameter model:
We use the following command to run WikiText-103 evaluation on a 345M parameter model. Make that `wikitext` is part of the file path.
<pre>
TASK="WIKITEXT103"

VALID_DATA=&#60;wikitext path&#62;
VALID_DATA=&#60;wikitext path&#62;.txt
VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt
CHECKPOINT_PATH=checkpoints/gpt2_345m
@@ -335,8 +335,8 @@ CHECKPOINT_PATH=checkpoints/gpt2_345m
COMMON_TASK_ARGS="--num-layers 24 \
                  --hidden-size 1024 \
                  --num-attention-heads 16 \
                  --seq-length 512 \
                  --max-position-embeddings 512 \
                  --seq-length 1024 \
                  --max-position-embeddings 1024 \
                  --fp16 \
                  --vocab-file $VOCAB_FILE"

@@ -359,12 +359,12 @@ python tasks/main.py \
### LAMBADA Cloze Accuracy
To compute LAMBADA cloze accuracy (the accuracy of predicting the last token given the preceeding tokens) we utilize a detokenized, processed version of the [LAMBADA dataset](https://github.com/cybertronai/bflm/blob/master/lambada_test.jsonl).

We use the following command to run LAMBADA evaluation on a 345M parameter model. Note that the `--strict-lambada` flag should be used to require whole word matching.
We use the following command to run LAMBADA evaluation on a 345M parameter model. Note that the `--strict-lambada` flag should be used to require whole word matching. Make that `lambada` is part of the file path.

<pre>
TASK="LAMBADA"

VALID_DATA=&#60;lambada path&#62;
VALID_DATA=&#60;lambada path&#62;.json
VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt
CHECKPOINT_PATH=checkpoints/gpt2_345m
@@ -400,17 +400,23 @@ VALID_DATA="data/RACE/dev/middle \
VOCAB_FILE=bert-vocab.txt
PRETRAINED_CHECKPOINT=checkpoints/bert_345m
CHECKPOINT_PATH=checkpoints/bert_345m_race
COMMON_TASK_ARGS=&#60;same as those in <a href="#wikitext-perplexity-evaluation">WikiText Perplexity Evaluation</a> above&#62;
COMMON_TASK_ARGS=COMMON_TASK_ARGS="--num-layers 24 \
                  --hidden-size 1024 \
                  --num-attention-heads 16 \
                  --seq-length 512 \
                  --max-position-embeddings 512 \
                  --fp16 \
                  --vocab-file $VOCAB_FILE"

COMMON_TASK_ARGS_EXT="--train-data $TRAIN_DATA \
                      --valid-data $VALID_DATA \
                      --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
                      --checkpoint-activations \
                      --save-interval 500000 \
                      --save-interval 10000 \
                      --save $CHECKPOINT_PATH \
                      --log-interval 10 \
                      --eval-interval 100 \
                      --eval-iters 50 \
                      --log-interval 100 \
                      --eval-interval 1000 \
                      --eval-iters 10 \
                      --weight-decay 1.0e-1"

python tasks/main.py \
@@ -436,8 +442,8 @@ VALID_DATA="data/glue_data/MNLI/dev_matched.tsv \
PRETRAINED_CHECKPOINT=checkpoints/bert_345m
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m_mnli
COMMON_TASK_ARGS=&#60;same as those in <a href="#lambada-cloze-accuracy">LAMBADA Cloze Accuracy</a> above&#62;
COMMON_TASK_ARGS_EXT=&#60;same as those in <a href="#race-evaluation">Race Evaluation</a> above&#62;
COMMON_TASK_ARGS=&#60;same as those in <a href="#race-evaluation">RACE Evaluation</a> above&#62;
COMMON_TASK_ARGS_EXT=&#60;same as those in <a href="#race-evaluation">RACE Evaluation</a> above&#62;

python tasks/main.py \
       --task MNLI \