Commit fac6718a authored by Jared Casper

Merge branch 'blendable_dataset' into 'megatron_sampler'

Blendable dataset

See merge request ADLR/megatron-lm!178
parents cebd3b8b 1eda0a17
megatron/arguments.py +11 −6
@@ -136,14 +136,16 @@ def parse_args(extra_args_provider=None, defaults={},
def _print_args(args):
    """Print arguments."""
    if args.rank == 0:
        print('-------------------- arguments --------------------', flush=True)
        print('------------------------ arguments ------------------------',
              flush=True)
        str_list = []
        for arg in vars(args):
            dots = '.' * (32 - len(arg))
            dots = '.' * (48 - len(arg))
            str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
        for arg in sorted(str_list, key=lambda x: x.lower()):
            print(arg, flush=True)
        print('---------------- end of arguments ----------------', flush=True)
        print('-------------------- end of arguments ---------------------',
              flush=True)


def _check_arg_is_not_none(args, arg):
@@ -278,7 +280,7 @@ def _add_learning_rate_args(parser):
                       'and initial warmup, the learning rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
                       choices=['constant', 'linear', 'cosine', 'exponential'],
                       choices=['constant', 'linear', 'cosine'],
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay learning rate over,'
@@ -400,8 +402,11 @@ def _add_validation_args(parser):
def _add_data_args(parser):
    group = parser.add_argument_group(title='data and dataloader')

    group.add_argument('--data-path', type=str, default=None,
                       help='Path to combined dataset to split.')
    group.add_argument('--data-path', nargs='*', default=None,
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
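
To make the two accepted --data-path forms concrete, here is a small stand-alone sketch using plain argparse in place of Megatron's parser (the dataset paths are hypothetical):

import argparse

# Minimal stand-in for Megatron's argument parser; dataset paths are hypothetical.
parser = argparse.ArgumentParser()
parser.add_argument('--data-path', nargs='*', default=None)

# 1) Single dataset: one data prefix.
single = parser.parse_args(['--data-path', '/data/wiki_text_document'])
print(single.data_path)
# ['/data/wiki_text_document']

# 2) Blended datasets: alternating weight / prefix tokens, later consumed by
#    get_datasets_weights_and_num_samples.
blended = parser.parse_args(['--data-path',
                             '0.3', '/data/wiki_text_document',
                             '0.7', '/data/books_text_document'])
print(blended.data_path)
# ['0.3', '/data/wiki_text_document', '0.7', '/data/books_text_document']
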
megatron/data/blendable_dataset.py +75 −0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Blendable dataset."""

import time

import numpy as np
import torch

from megatron import print_rank_0
from megatron import mpu


class BlendableDataset(torch.utils.data.Dataset):


    def __init__(self, datasets, weights):

        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)

        self.size = 0
        for dataset in self.datasets:
            self.size += len(dataset)

        # Normalize weights.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0
        weights /= sum_weights

        # Build indices.
        start_time = time.time()
        assert num_datasets < 255
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        if torch.distributed.get_rank() == 0:
            from megatron.data.dataset_utils import compile_helper
            compile_helper()
        # Simple barrier
        tmp = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(tmp, group=mpu.get_data_parallel_group())

        from megatron.data import helpers
        helpers.build_blending_indices(self.dataset_index,
                                       self.dataset_sample_index,
                                       weights, num_datasets, self.size,
                                       torch.distributed.get_rank() == 0)
        print_rank_0('> elapsed time for building blendable dataset indices: '
                     '{:.2f} (sec)'.format(time.time() - start_time))


    def __len__(self):
        return self.size


    def __getitem__(self, idx):
        dataset_idx = self.dataset_index[idx]
        sample_idx = self.dataset_sample_index[idx]
        return self.datasets[dataset_idx][sample_idx]
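
A minimal stand-alone sketch of what the class above computes, with the distributed barrier and the compiled helpers.build_blending_indices replaced by a plain Python loop (the toy datasets, weights, and the name build_blending_indices_py are illustrative, not Megatron APIs):

import numpy as np

def build_blending_indices_py(weights, size):
    """Pure-Python stand-in for helpers.build_blending_indices: at each step,
    draw from the dataset whose achieved sample count lags its target
    weight * step the most."""
    weights = np.asarray(weights, dtype=np.float64)
    weights = weights / weights.sum()
    dataset_index = np.zeros(size, dtype=np.uint8)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    current_samples = np.zeros(len(weights), dtype=np.int64)
    for idx in range(size):
        errors = weights * max(idx, 1) - current_samples
        chosen = int(np.argmax(errors))
        dataset_index[idx] = chosen
        dataset_sample_index[idx] = current_samples[chosen]
        current_samples[chosen] += 1
    return dataset_index, dataset_sample_index

# Toy datasets whose sizes match the weights, so every generated sample index
# stays in range (Megatron sizes the real datasets from the same weights).
datasets = [['a0', 'a1', 'a2'], ['b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6']]
weights = [0.3, 0.7]
size = sum(len(d) for d in datasets)

dataset_index, dataset_sample_index = build_blending_indices_py(weights, size)

# __getitem__ above then just dispatches through the two index arrays.
blended = [datasets[dataset_index[i]][dataset_sample_index[i]] for i in range(size)]
print(blended)
# ['b0', 'a0', 'b1', 'b2', 'a1', 'b3', 'b4', 'a2', 'b5', 'b6']
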
megatron/data/dataset_utils.py +74 −0
@@ -18,11 +18,13 @@
#   https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.

import math
import time
import collections

import numpy as np
from megatron import get_args, print_rank_0
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset

DSET_TYPE_STD = 'standard_bert'
@@ -31,6 +33,38 @@ DSET_TYPE_ICT = 'ict'
DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]


def get_datasets_weights_and_num_samples(data_prefix,
                                         train_valid_test_num_samples):

    # The data prefix should be in the format of:
    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
    assert len(data_prefix) % 2 == 0
    num_datasets = len(data_prefix) // 2
    weights = [0]*num_datasets
    prefixes = [0]*num_datasets
    for i in range(num_datasets):
        weights[i] = float(data_prefix[2*i])
        prefixes[i] = (data_prefix[2*i+1]).strip()
    # Normalize weights
    weight_sum = 0.0
    for weight in weights:
        weight_sum += weight
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]

    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
    # not uniformly distribute the number of samples, we still have
    # samples left to feed to the network.
    datasets_train_valid_test_num_samples = []
    for weight in weights:
        datasets_train_valid_test_num_samples.append(
            [int(math.ceil(val * weight * 1.005))
             for val in train_valid_test_num_samples])


    return prefixes, weights, datasets_train_valid_test_num_samples
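
# Illustrative example (hypothetical data prefixes): calling
#   get_datasets_weights_and_num_samples(
#       ['0.3', '/data/wiki_text_document', '0.7', '/data/books_text_document'],
#       [1000, 100, 10])
# normalizes the weights to [0.3, 0.7] and, with the 1.005 over-sampling
# factor, returns per-dataset train/valid/test sample counts of
#   [[302, 31, 4], [704, 71, 8]].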


def compile_helper():
    """Compile helper function ar runtime. Make sure this
    is invoked on a single process."""
@@ -360,6 +394,46 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    short_seq_prob, seed, skip_warmup,
                                    dataset_type='standard_bert'):

    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                dataset_type=dataset_type)
    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, dataset_type=dataset_type)
        train_datasets.append(train_ds)
        valid_datasets.append(valid_ds)
        test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
                                     dataset_type='standard_bert'):
    
    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: ", dataset_type)

megatron/data/gpt2_dataset.py +42 −0
@@ -22,6 +22,8 @@ import numpy as np
import torch

from megatron import mpu, print_rank_0
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from megatron.data.dataset_utils import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset

@@ -31,6 +33,46 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    seq_length, seed, skip_warmup):
    """Build train, valid, and test datasets."""

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                seq_length, seed, skip_warmup)

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            seq_length, seed, skip_warmup)
        train_datasets.append(train_ds)
        valid_datasets.append(valid_ds)
        test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     seq_length, seed, skip_warmup):
    """Build train, valid, and test datasets."""

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
megatron/data/helpers.cpp +64 −0
@@ -33,6 +33,69 @@ using namespace std;
const int32_t LONG_SENTENCE_LEN = 512;


void build_blending_indices(py::array_t<uint8_t>& dataset_index,
			    py::array_t<int64_t>& dataset_sample_index,
			    const py::array_t<double>& weights,
			    const int32_t num_datasets,
			    const int64_t size, const bool verbose) {
  /* Given multiple datasets and a weighting array, build samples
   such that it follows those weights. */

  if (verbose) {
    std::cout << "> building indices for blendable datasets ..." << std::endl;
  }

  // Get the pointer access without the checks.
  auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
  auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
  auto weights_ptr = weights.unchecked<1>();

  // Initialize buffer for number of samples used for each dataset.
  int64_t current_samples[num_datasets];
  for(int64_t i = 0; i < num_datasets; ++i) {
    current_samples[i] = 0;
  }

  // For each sample:
  for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {

    // Determine where the max error in sampling is happening.
    auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
    int64_t max_error_index = 0;
    double max_error = weights_ptr[0] * sample_idx_double -
      static_cast<double>(current_samples[0]);
    for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
      double error = weights_ptr[dataset_idx] * sample_idx_double -
	static_cast<double>(current_samples[dataset_idx]);
      if (error > max_error) {
	max_error = error;
	max_error_index = dataset_idx;
      }
    }

    // Populate the indices.
    dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
    dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];

    // Update the total samples.
    current_samples[max_error_index] += 1;
    
  }

  // print info
  if (verbose) {
    std::cout << " > sample ratios:" << std::endl;
    for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
      auto ratio = static_cast<double>(current_samples[dataset_idx]) /
	static_cast<double>(size);
      std::cout << "   dataset " << dataset_idx << ", input: " <<
	weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; 
    }
  }

}
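
/* Worked example (illustrative, hypothetical inputs): with
   weights = {0.25, 0.75} and size = 8, the loop above always draws from the
   dataset whose achieved count lags its target weight * sample_idx the most,
   producing
     dataset_index        = 1 0 1 1 0 1 1 1
     dataset_sample_index = 0 0 1 2 1 3 4 5
   so dataset 0 supplies 2/8 = 0.25 of the samples and dataset 1 supplies
   6/8 = 0.75, matching the input weights. */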


py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
			   const py::array_t<int32_t>& doc_idx_,
			   const int32_t seq_length,
@@ -640,4 +703,5 @@ PYBIND11_MODULE(helpers, m) {
    m.def("build_mapping", &build_mapping);
    m.def("build_blocks_mapping", &build_blocks_mapping);
    m.def("build_sample_idx", &build_sample_idx);
    m.def("build_blending_indices", &build_blending_indices);
}