This diff adds support for blending multiple weighted datasets during pretraining.

megatron/arguments.py (+11 −6)

```diff
@@ -136,14 +136,16 @@ def parse_args(extra_args_provider=None, defaults={},
 def _print_args(args):
     """Print arguments."""
     if args.rank == 0:
-        print('-------------------- arguments --------------------', flush=True)
+        print('------------------------ arguments ------------------------',
+              flush=True)
         str_list = []
         for arg in vars(args):
-            dots = '.' * (32 - len(arg))
+            dots = '.' * (48 - len(arg))
             str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
         for arg in sorted(str_list, key=lambda x: x.lower()):
            print(arg, flush=True)
-        print('---------------- end of arguments ----------------', flush=True)
+        print('-------------------- end of arguments ---------------------',
+              flush=True)
@@ -278,7 +280,7 @@ def _add_learning_rate_args(parser):
                        'and initial warmup, the learning rate at each '
                        'iteration would be different.')
     group.add_argument('--lr-decay-style', type=str, default='linear',
-                       choices=['constant', 'linear', 'cosine', 'exponential'],
+                       choices=['constant', 'linear', 'cosine'],
                        help='Learning rate decay function.')
     group.add_argument('--lr-decay-iters', type=int, default=None,
                        help='number of iterations to decay learning rate over,'
@@ -400,8 +402,11 @@ def _add_validation_args(parser):
 def _add_data_args(parser):
     group = parser.add_argument_group(title='data and dataloader')

-    group.add_argument('--data-path', type=str, default=None,
-                       help='Path to combined dataset to split.')
+    group.add_argument('--data-path', nargs='*', default=None,
+                       help='Path to the training dataset. Accepted format: '
+                       '1) a single data path, 2) multiple datasets in the '
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
```
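For reference, a minimal sketch of how the reworked `--data-path` flag parses; the dataset prefixes below are hypothetical, not taken from this PR:

```python
# Minimal sketch of the new nargs='*' parsing for --data-path.
# 'books_text_document' / 'wiki_text_document' are illustrative prefixes.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data-path', nargs='*', default=None)

# 1) A single data path comes back as a one-element list.
args = parser.parse_args(['--data-path', 'wiki_text_document'])
assert args.data_path == ['wiki_text_document']

# 2) Multiple datasets alternate weight and prefix tokens:
#    dataset1-weight dataset1-path dataset2-weight dataset2-path ...
args = parser.parse_args(['--data-path',
                          '0.3', 'books_text_document',
                          '0.7', 'wiki_text_document'])
assert len(args.data_path) % 2 == 0
```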
megatron/data/blendable_dataset.py (new file, +75 −0)

```python
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Blendable dataset."""

import time

import numpy as np
import torch

from megatron import print_rank_0
from megatron import mpu


class BlendableDataset(torch.utils.data.Dataset):

    def __init__(self, datasets, weights):

        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)

        self.size = 0
        for dataset in self.datasets:
            self.size += len(dataset)

        # Normalize weights.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0
        weights /= sum_weights

        # Build indices.
        start_time = time.time()
        assert num_datasets < 255
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        if torch.distributed.get_rank() == 0:
            from megatron.data.dataset_utils import compile_helper
            compile_helper()
        # Simple barrier.
        tmp = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(tmp, group=mpu.get_data_parallel_group())
        from megatron.data import helpers
        helpers.build_blending_indices(self.dataset_index,
                                       self.dataset_sample_index,
                                       weights, num_datasets, self.size,
                                       torch.distributed.get_rank() == 0)
        print_rank_0('> elapsed time for building blendable dataset indices: '
                     '{:.2f} (sec)'.format(time.time() - start_time))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        dataset_idx = self.dataset_index[idx]
        sample_idx = self.dataset_sample_index[idx]
        return self.datasets[dataset_idx][sample_idx]
```
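As a toy illustration of the indexing contract (not from this PR, and not runnable against the real class as-is, since its constructor compiles the C++ helper and expects `torch.distributed` to be initialized), a hypothetical `ToyDataset` makes the dispatch in `__getitem__` concrete:

```python
# Hypothetical stand-in datasets; BlendableDataset itself is not constructed
# here because it needs torch.distributed and the compiled helper.
class ToyDataset:
    def __init__(self, name, length):
        self.name = name
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return (self.name, idx)


datasets = [ToyDataset('books', 4), ToyDataset('wiki', 2)]

# The blended length is the plain sum of the constituent lengths...
size = sum(len(d) for d in datasets)   # 6
# ...and each blended index resolves through the two lookup tables:
#     sample = datasets[dataset_index[idx]][dataset_sample_index[idx]]
# e.g. dataset_index[idx] == 1 and dataset_sample_index[idx] == 0
# would return ('wiki', 0).
```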
megatron/data/dataset_utils.py (+74 −0)

```diff
@@ -18,11 +18,13 @@
 # https://github.com/google-research/albert/blob/master/create_pretraining_data.py
 # with some modifications.

+import math
 import time
 import collections

 import numpy as np

 from megatron import get_args, print_rank_0
+from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset

 DSET_TYPE_STD = 'standard_bert'
@@ -31,6 +33,38 @@ DSET_TYPE_ICT = 'ict'

 DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]


+def get_datasets_weights_and_num_samples(data_prefix,
+                                         train_valid_test_num_samples):
+
+    # The data prefix should be in the format of:
+    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ...
+    assert len(data_prefix) % 2 == 0
+    num_datasets = len(data_prefix) // 2
+    weights = [0] * num_datasets
+    prefixes = [0] * num_datasets
+    for i in range(num_datasets):
+        weights[i] = float(data_prefix[2 * i])
+        prefixes[i] = (data_prefix[2 * i + 1]).strip()
+    # Normalize weights.
+    weight_sum = 0.0
+    for weight in weights:
+        weight_sum += weight
+    assert weight_sum > 0.0
+    weights = [weight / weight_sum for weight in weights]
+    # Add 0.5% (the 1.005 factor) so that if the blending dataset does
+    # not uniformly distribute the number of samples, we still have
+    # samples left to feed to the network.
+    datasets_train_valid_test_num_samples = []
+    for weight in weights:
+        datasets_train_valid_test_num_samples.append(
+            [int(math.ceil(val * weight * 1.005))
+             for val in train_valid_test_num_samples])
+
+    return prefixes, weights, datasets_train_valid_test_num_samples
+
+
 def compile_helper():
     """Compile helper function at runtime. Make sure this
     is invoked on a single process."""
@@ -360,6 +394,46 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     short_seq_prob, seed, skip_warmup,
                                     dataset_type='standard_bert'):

+    if len(data_prefix) == 1:
+        return _build_train_valid_test_datasets(data_prefix[0],
+                                                data_impl, splits_string,
+                                                train_valid_test_num_samples,
+                                                max_seq_length, masked_lm_prob,
+                                                short_seq_prob, seed,
+                                                skip_warmup,
+                                                dataset_type=dataset_type)
+
+    # Blending dataset.
+    # Parse the values.
+    output = get_datasets_weights_and_num_samples(data_prefix,
+                                                  train_valid_test_num_samples)
+    prefixes, weights, datasets_train_valid_test_num_samples = output
+
+    # Build individual datasets.
+    train_datasets = []
+    valid_datasets = []
+    test_datasets = []
+    for i in range(len(prefixes)):
+        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
+            prefixes[i], data_impl, splits_string,
+            datasets_train_valid_test_num_samples[i],
+            max_seq_length, masked_lm_prob, short_seq_prob,
+            seed, skip_warmup, dataset_type=dataset_type)
+        train_datasets.append(train_ds)
+        valid_datasets.append(valid_ds)
+        test_datasets.append(test_ds)
+
+    # Blend.
+    blending_train_dataset = BlendableDataset(train_datasets, weights)
+    blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+    blending_test_dataset = BlendableDataset(test_datasets, weights)
+
+    return (blending_train_dataset, blending_valid_dataset,
+            blending_test_dataset)
+
+
+def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                     train_valid_test_num_samples,
+                                     max_seq_length, masked_lm_prob,
+                                     short_seq_prob, seed, skip_warmup,
+                                     dataset_type='standard_bert'):

     if dataset_type not in DSET_TYPES:
         raise ValueError("Invalid dataset_type: ", dataset_type)
```

megatron/data/gpt2_dataset.py (+42 −0)

```diff
@@ -22,6 +22,8 @@ import numpy as np
 import torch

 from megatron import mpu, print_rank_0
+from megatron.data.blendable_dataset import BlendableDataset
+from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
 from megatron.data.dataset_utils import get_train_valid_test_split_
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset

@@ -31,6 +33,46 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     seq_length, seed, skip_warmup):
     """Build train, valid, and test datasets."""

+    # Single dataset.
+    if len(data_prefix) == 1:
+        return _build_train_valid_test_datasets(data_prefix[0],
+                                                data_impl, splits_string,
+                                                train_valid_test_num_samples,
+                                                seq_length, seed, skip_warmup)
+
+    # Blending dataset.
+    # Parse the values.
+    output = get_datasets_weights_and_num_samples(data_prefix,
+                                                  train_valid_test_num_samples)
+    prefixes, weights, datasets_train_valid_test_num_samples = output
+
+    # Build individual datasets.
+    train_datasets = []
+    valid_datasets = []
+    test_datasets = []
+    for i in range(len(prefixes)):
+        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
+            prefixes[i], data_impl, splits_string,
+            datasets_train_valid_test_num_samples[i],
+            seq_length, seed, skip_warmup)
+        train_datasets.append(train_ds)
+        valid_datasets.append(valid_ds)
+        test_datasets.append(test_ds)
+
+    # Blend.
+    blending_train_dataset = BlendableDataset(train_datasets, weights)
+    blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+    blending_test_dataset = BlendableDataset(test_datasets, weights)
+
+    return (blending_train_dataset, blending_valid_dataset,
+            blending_test_dataset)
+
+
+def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                     train_valid_test_num_samples,
+                                     seq_length, seed, skip_warmup):
+    """Build train, valid, and test datasets."""
+
     # Indexed dataset.
     indexed_dataset = get_indexed_dataset_(data_prefix, data_impl,
```
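A worked example of `get_datasets_weights_and_num_samples` (a sketch assuming a Megatron-LM checkout is importable; the prefixes and sample counts are illustrative):

```python
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples

# Token list exactly as argparse hands it over from --data-path.
data_prefix = ['0.3', 'books_text_document', '0.7', 'wiki_text_document']
train_valid_test_num_samples = [1000, 100, 10]

prefixes, weights, per_dataset = get_datasets_weights_and_num_samples(
    data_prefix, train_valid_test_num_samples)

# prefixes    == ['books_text_document', 'wiki_text_document']
# weights     == [0.3, 0.7]
# per_dataset == [[302, 31, 4], [704, 71, 8]]
# Each entry is ceil(total * weight * 1.005), e.g.
# ceil(1000 * 0.3 * 1.005) == 302; the 0.5% headroom keeps the blend from
# running out of samples when the mix does not divide up evenly.
```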
megatron/data/helpers.cpp (+64 −0)

```diff
@@ -33,6 +33,69 @@ using namespace std;

 const int32_t LONG_SENTENCE_LEN = 512;


+void build_blending_indices(py::array_t<uint8_t>& dataset_index,
+                            py::array_t<int64_t>& dataset_sample_index,
+                            const py::array_t<double>& weights,
+                            const int32_t num_datasets,
+                            const int64_t size, const bool verbose) {
+    /* Given multiple datasets and a weighting array, build samples
+       such that it follows those weights. */
+
+    if (verbose) {
+        std::cout << "> building indices for blendable datasets ..." << std::endl;
+    }
+
+    // Get the pointer access without the checks.
+    auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
+    auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
+    auto weights_ptr = weights.unchecked<1>();
+
+    // Initialize buffer for number of samples used for each dataset.
+    int64_t current_samples[num_datasets];
+    for (int64_t i = 0; i < num_datasets; ++i) {
+        current_samples[i] = 0;
+    }
+
+    // For each sample:
+    for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
+
+        // Determine where the max error in sampling is happening.
+        auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
+        int64_t max_error_index = 0;
+        double max_error = weights_ptr[0] * sample_idx_double -
+            static_cast<double>(current_samples[0]);
+        for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
+            double error = weights_ptr[dataset_idx] * sample_idx_double -
+                static_cast<double>(current_samples[dataset_idx]);
+            if (error > max_error) {
+                max_error = error;
+                max_error_index = dataset_idx;
+            }
+        }
+
+        // Populate the indices.
+        dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
+        dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
+
+        // Update the total samples.
+        current_samples[max_error_index] += 1;
+    }
+
+    // Print info.
+    if (verbose) {
+        std::cout << " > sample ratios:" << std::endl;
+        for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
+            auto ratio = static_cast<double>(current_samples[dataset_idx]) /
+                static_cast<double>(size);
+            std::cout << "   dataset " << dataset_idx << ", input: " <<
+                weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
+        }
+    }
+}
+
+
 py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
                            const py::array_t<int32_t>& doc_idx_,
                            const int32_t seq_length,
@@ -640,4 +703,5 @@ PYBIND11_MODULE(helpers, m) {
     m.def("build_mapping", &build_mapping);
     m.def("build_blocks_mapping", &build_blocks_mapping);
     m.def("build_sample_idx", &build_sample_idx);
+    m.def("build_blending_indices", &build_blending_indices);
 }
```
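When reading the C++, a pure-Python transcription of the greedy scheme may help: each step assigns the next blended slot to the dataset whose achieved sample count lags furthest behind its weighted target. This is a sketch for illustration only; training uses the compiled helper.

```python
import numpy as np

def blending_indices(weights, size):
    """Pure-Python mirror of helpers.build_blending_indices (illustrative)."""
    weights = np.asarray(weights, dtype=np.float64)
    dataset_index = np.zeros(size, dtype=np.uint8)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    current = np.zeros(len(weights), dtype=np.int64)
    for i in range(size):
        # Error = weighted target so far minus samples actually drawn;
        # pick the dataset that is furthest behind (first one on ties,
        # matching the strict '>' comparison in the C++).
        errors = weights * max(i, 1) - current
        d = int(np.argmax(errors))
        dataset_index[i] = d
        dataset_sample_index[i] = current[d]
        current[d] += 1
    return dataset_index, dataset_sample_index

# With weights [0.7, 0.3] and size 10, dataset 0 gets 7 of the 10 slots:
idx, _ = blending_indices([0.7, 0.3], 10)
assert (idx == 0).sum() == 7   # matches the "input vs. achieved" printout
```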