Commit 48a5e0da authored by Jared Casper's avatar Jared Casper
Browse files

Integrate code from t5_main into existing code.

parent f32a638d
Loading
Loading
Loading
Loading
+6 −3
Original line number Diff line number Diff line
@@ -625,6 +625,9 @@ def _add_data_args(parser):
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
    group.add_argument('--vocab-extra-ids', type=int, default=0,
                       help='Number of additional vocabulary tokens. '
                            'They are used for span masking in the T5 model')
    group.add_argument('--seq-length', type=int, default=None,
                       help='Maximum sequence length to process.')
    group.add_argument('--encoder-seq-length', type=int, default=None,
+56 −108
Original line number Diff line number Diff line
@@ -15,24 +15,25 @@

"""BERT Style dataset."""

import os
import time

import numpy as np
import torch
from torch.utils.data import Dataset

from megatron import get_tokenizer, get_args
from megatron import print_rank_0
from megatron import mpu
from megatron.data.dataset_utils import get_a_and_b_segments
from megatron.data.dataset_utils import truncate_segments
from megatron.data.dataset_utils import create_tokens_and_tokentypes
from megatron.data.dataset_utils import pad_and_convert_to_numpy
from megatron.data.dataset_utils import create_masked_lm_predictions
from megatron import (
    get_args,
    get_tokenizer,
    mpu,
    print_rank_0
)
from megatron.data.dataset_utils import (
    get_samples_mapping,
    get_a_and_b_segments,
    truncate_segments,
    create_tokens_and_tokentypes,
    create_masked_lm_predictions
)


class BertDataset(Dataset):
class BertDataset(torch.utils.data.Dataset):

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
@@ -49,11 +50,11 @@ class BertDataset(Dataset):
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
                                                   data_prefix,
                                                   num_epochs,
                                                   max_num_samples,
                                                    self.max_seq_length,
                                                   self.max_seq_length - 3, # account for added tokens
                                                   short_seq_prob,
                                                   self.seed,
                                                   self.name,
@@ -87,91 +88,6 @@ class BertDataset(Dataset):
                                     self.binary_head)


def get_samples_mapping_(indexed_dataset,
                         data_prefix,
                         num_epochs,
                         max_num_samples,
                         max_seq_length,
                         short_seq_prob,
                         seed,
                         name,
                         binary_head):
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if torch.distributed.get_rank() == 0 and \
       not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert indexed_dataset.doc_idx.dtype == np.int64
        assert indexed_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building sapmles index mapping for {} ...'.format(
            name))
        # First compile and then import.
        from megatron.data import helpers
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            short_seq_prob,
            seed,
            verbose,
            2 if binary_head else 1)
        print_rank_0(' > done building sapmles index maping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elasped time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
                         time.time() - start_time))
    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
    assert counts[0].item() == (
        torch.distributed.get_world_size() //
        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping


def build_training_sample(sample,
@@ -225,7 +141,7 @@ def build_training_sample(sample,

    # Masking.
    max_predictions_per_seq = masked_lm_prob * max_num_tokens
    (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
    (tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

@@ -244,3 +160,35 @@ def build_training_sample(sample,
        'truncated': int(truncated)}
    return train_sample


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Lables and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
+164 −38
Original line number Diff line number Diff line
@@ -19,18 +19,26 @@
# with some modifications.

import math
import os
import time
import collections

import numpy as np
from megatron import get_args, print_rank_0
import torch

from megatron import (
    get_args,
    mpu,
    print_rank_0
)
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset

DSET_TYPE_STD = 'standard_bert'
DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5  = 't5'

DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]


def get_datasets_weights_and_num_samples(data_prefix,
@@ -179,7 +187,9 @@ def create_masked_lm_predictions(tokens,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False):
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert"):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""

@@ -219,12 +229,12 @@ def create_masked_lm_predictions(tokens,
    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    if not geometric_dist:
        # Note(mingdachen):
        # By default, we set the probilities to favor shorter ngram sequences.
    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
        pvals = 1. / np.arange(1, max_ngrams + 1)
        pvals /= pvals.sum(keepdims=True)

        if favor_longer_ngram:
            pvals = pvals[::-1]

@@ -237,7 +247,7 @@ def create_masked_lm_predictions(tokens,

    np_rng.shuffle(ngram_indexes)

    masked_lms = []
    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
@@ -251,9 +261,16 @@ def create_masked_lm_predictions(tokens,
                if index in covered_indexes:
                    continue

        if not geometric_dist:
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
        else:
            # Sampling "n" from the geometric distribution and clipping it to
            # the max_ngrams. Using p=0.2 default from the SpanBERT paper
            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
@@ -277,8 +294,8 @@ def create_masked_lm_predictions(tokens,
            continue
        for index in index_set:
            covered_indexes.add(index)

            masked_token = None
            if masking_style == "bert":
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
@@ -289,12 +306,19 @@ def create_masked_lm_predictions(tokens,
                    # 10% of the time, replace with random word
                    else:
                        masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token

            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
    assert len(masked_lms) <= num_to_predict

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict
    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
@@ -347,12 +371,13 @@ def create_masked_lm_predictions(tokens,
            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of the first span
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)

    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
@@ -390,9 +415,10 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,

def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length, masked_lm_prob,
                                    short_seq_prob, seed, skip_warmup,
                                    binary_head,
                                    max_seq_length,
                                    masked_lm_prob, short_seq_prob, seed,
                                    skip_warmup, binary_head=False,
                                    max_seq_length_dec=None,
                                    dataset_type='standard_bert'):

    if len(data_prefix) == 1:
@@ -403,6 +429,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                binary_head,
                                                max_seq_length_dec,
                                                dataset_type=dataset_type)
    # Blending dataset.
    # Parse the values.
@@ -444,9 +471,10 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,

def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
                                     binary_head,
                                     max_seq_length,
                                     masked_lm_prob, short_seq_prob, seed,
                                     skip_warmup, binary_head,
                                     max_seq_length_dec,
                                     dataset_type='standard_bert'):

    if dataset_type not in DSET_TYPES:
@@ -489,6 +517,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    def build_dataset(index, name):
        from megatron.data.bert_dataset import BertDataset
        from megatron.data.ict_dataset import ICTDataset
        from megatron.data.t5_dataset import T5Dataset
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
@@ -507,7 +536,6 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
                binary_head=binary_head
            )

            if dataset_type == DSET_TYPE_ICT:
@@ -517,15 +545,27 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                    title_dataset=title_dataset,
                    query_in_block_prob=args.query_in_block_prob,
                    use_one_sent_docs=args.use_one_sent_docs,
                    binary_head=binary_head,
                    **kwargs
                )
            else:
            elif dataset_type == DSET_TYPE_T5:
                dataset = T5Dataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    max_seq_length_dec=max_seq_length_dec,
                    short_seq_prob=short_seq_prob,
                    **kwargs
                )
            elif dataset_type == DSET_TYPE_BERT:
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    binary_head=binary_head,
                    **kwargs
                )
            else:
                raise NotImplementedError("Dataset type not fully implemented.")

            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
@@ -590,4 +630,90 @@ def get_train_valid_test_split_(splits_string, size):
    assert splits_index[-1] == size
    return splits_index

def get_samples_mapping(indexed_dataset,
                        data_prefix,
                        num_epochs,
                        max_num_samples,
                        max_seq_length,
                        short_seq_prob,
                        seed,
                        name,
                        binary_head):
    """Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if torch.distributed.get_rank() == 0 and \
       not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert indexed_dataset.doc_idx.dtype == np.int64
        assert indexed_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building sapmles index mapping for {} ...'.format(
            name))
        # First compile and then import.
        from megatron.data import helpers
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length,
            short_seq_prob,
            seed,
            verbose,
            2 if binary_head else 1)
        print_rank_0(' > done building sapmles index maping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elasped time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
                         time.time() - start_time))
    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
    assert counts[0].item() == (
        torch.distributed.get_world_size() //
        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping
+270 −0

File added.

Preview size limit exceeded, changes collapsed.

+1 −0
Original line number Diff line number Diff line
@@ -18,5 +18,6 @@ from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from .distributed import DistributedDataParallel
from .bert_model import BertModel
from .gpt_model import GPTModel
from .t5_model import T5Model
from .language_model import get_language_model
from .module import Float16Module
Loading