megatron/data/Makefile (new file, mode 100644, +9 −0)

+CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
+CPPFLAGS += $(shell python3 -m pybind11 --includes)
+LIBNAME = helpers
+LIBEXT = $(shell python3-config --extension-suffix)
+
+default: $(LIBNAME)$(LIBEXT)
+
+%$(LIBEXT): %.cpp
+	$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
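The Makefile compiles helpers.cpp into a pybind11 extension next to the Python sources, which the new dataset code imports as megatron.data.helpers. A minimal sanity check, assuming `make` has already been run inside megatron/data (the error message below is only a suggestion, not part of this change):

import importlib

try:
    helpers = importlib.import_module('megatron.data.helpers')
except ImportError as exc:
    # Assumes the extension is built manually with the Makefile above.
    raise ImportError('megatron/data/helpers is not compiled; '
                      'run `make` inside megatron/data first') from exc

# pybind11 generates a signature docstring for the exported function.
print(helpers.build_mapping.__doc__)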
megatron/data/__init__.py (+1 −0)

 from . import indexed_dataset
 from .bert_tokenization import FullTokenizer as FullBertTokenizer
+from .dataset import AlbertDataset
megatron/data/dataset.py (+30 −10)

@@ -7,27 +7,36 @@
 import numpy as np
 import torch
 from torch.utils.data import Dataset

-from dataset_utils import build_training_sample
-#from data.mapping import build_training_samples_mapping
+from .dataset_utils import build_training_sample
+from . import helpers
+from megatron.data import FullBertTokenizer, indexed_dataset


-class AlbertDataSet(Dataset):
+class AlbertDataset(Dataset):

-    def __init__(self, indexed_dataset, tokenizer, num_epochs,
+    def __init__(self, indexed_dataset, tokenizer, num_epochs, max_num_samples,
                  masked_lm_prob, max_seq_length, short_seq_prob, seed):

         # Params to store.
         self.seed = seed
         self.masked_lm_prob = masked_lm_prob
         self.max_seq_length = max_seq_length
+        self.tokenizer = tokenizer

         # Indexed dataset.
         self.indexed_dataset = indexed_dataset

         # Build the samples mapping.
-        self.samples_mapping = build_training_samples_mapping(
-            indexed_dataset,
+        if not max_num_samples:
+            max_num_samples = len(indexed_dataset) * num_epochs
+        self.samples_mapping = helpers.build_mapping(
+            indexed_dataset.doc_idx,
+            indexed_dataset.sizes,
             num_epochs,
-            self.max_seq_length,
+            max_num_samples,
+            self.max_seq_length-3,  # account for added tokens
             short_seq_prob,
             self.seed)

@@ -40,8 +49,17 @@ class AlbertDataSet(Dataset):
         self.pad_id = tokenizer.vocab['[PAD]']

+    @classmethod
+    def from_paths(cls, vocab, data_prefix, data_impl, num_epochs,
+                   max_num_samples, masked_lm_prob, max_seq_length,
+                   short_seq_prob, seed):
+        tokenizer = FullBertTokenizer(vocab, do_lower_case=True)
+        idx_ds = indexed_dataset.make_dataset(data_prefix, data_impl)
+        return cls(idx_ds, tokenizer, num_epochs, max_num_samples,
+                   masked_lm_prob, max_seq_length, short_seq_prob, seed)
+
     def __len__(self):
-        return self.samples.shape[0]
+        return self.samples_mapping.shape[0]

     def __getitem__(self, idx):
         rng = random.Random(self.seed + idx)

@@ -49,6 +67,9 @@ class AlbertDataSet(Dataset):
         sample = []
         for index in range(start_index, end_index):
             sample.append(self.indexed_dataset[index])
+        for s in sample:
+            if len(s) > 1000:
+                print(self.tokenizer.convert_ids_to_tokens(s))
         return build_training_sample(sample, seq_length,
                                      self.max_seq_length,
                                      self.vocab_id_list,

@@ -186,7 +207,6 @@ class JaredDataset(object):
 if __name__ == '__main__':

     print('dataset ...')
     from bert_tokenization import FullTokenizer

@@ -207,8 +227,8 @@ if __name__ == '__main__':
             sentences.extend(sent)
         yield sentences

-    input_file = '/raid/mshoeybi/data/albert/sample/samples_11.json'
-    vocab_file = '/raid/mshoeybi/data/albert/bert_vocab/vocab.txt'
+    input_file = 'test/samples_10000.json'
+    vocab_file = 'test/vocab.txt'

     tokenizer = FullTokenizer(vocab_file, do_lower_case=True)
     document_generator = document_generator_provider(input_file)
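A rough usage sketch of the new from_paths constructor. All paths and hyperparameters below are placeholders, 'mmap' is assumed to be one of the data_impl values that indexed_dataset.make_dataset accepts, and the compiled helpers extension must be available:

from megatron.data import AlbertDataset

train_ds = AlbertDataset.from_paths(
    vocab='vocab.txt',            # placeholder vocab file
    data_prefix='my-corpus',      # placeholder prefix of a preprocessed indexed dataset
    data_impl='mmap',
    num_epochs=1,
    max_num_samples=None,         # falsy -> len(indexed_dataset) * num_epochs
    masked_lm_prob=0.15,
    max_seq_length=512,
    short_seq_prob=0.1,
    seed=1234)

print(len(train_ds))              # rows in the mapping built by helpers.build_mapping
sample = train_ds[0]              # dict including 'labels', 'is_random', 'loss_mask',
                                  # 'padding_mask' and the new 'truncated' flag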
megatron/data/dataset_utils.py (+12 −12)

@@ -35,9 +35,8 @@ def build_training_sample(sample,
     tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, rng)

     # Truncate to `target_sequence_length`.
-    # Note that we have account for [CLS] A [SEP] B [SEP]
-    max_num_tokens = target_seq_length - 3
-    truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b),
+    max_num_tokens = target_seq_length
+    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b),
                       max_num_tokens, rng)

     # Build tokens and toketypes.

@@ -48,7 +47,7 @@ def build_training_sample(sample,
     max_predictions_per_seq = masked_lm_prob * max_num_tokens
     (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
         tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
-        cls_id, sep_id, mask_id, max_predictions_per_seq)
+        cls_id, sep_id, mask_id, max_predictions_per_seq, rng)

     # Padding.
     tokens_np, tokentypes_np, labels, padding_mask, loss_mask \

@@ -61,7 +60,8 @@ def build_training_sample(sample,
                     'labels': labels,
                     'is_random': int(is_next_random),
                     'loss_mask': loss_mask,
-                    'padding_mask': padding_mask}
+                    'padding_mask': padding_mask,
+                    'truncated': int(truncated)}
     return train_sample

@@ -99,11 +99,12 @@ def get_a_and_b_segments(sample, rng):
 def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, rng):
     """Truncates a pair of sequences to a maximum sequence length."""
     #print(len_a, len_b, max_num_tokens)
     assert len_a > 0
     assert len_b > 0
-    if (len_a + len_b) <= max_num_tokens:
-        return
-    else:
+    if len_a + len_b <= max_num_tokens:
+        return False
+    while len_a + len_b > max_num_tokens:
         if len_a > len_b:
             len_a -= 1
             tokens = tokens_a

@@ -114,8 +115,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, rng):
             del tokens[0]
         else:
             tokens.pop()
-    truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, rng)
+    return True


 def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
     """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""

@@ -161,6 +161,7 @@ def create_masked_lm_predictions(tokens,
                                  masked_lm_prob,
                                  cls_id, sep_id, mask_id,
                                  max_predictions_per_seq,
+                                 rng,
                                  max_ngrams=3,
                                  do_whole_word_mask=True,
                                  favor_longer_ngram=False,

@@ -468,4 +469,3 @@ if __name__ == '__main__':
         string += '{:5d}'.format(tokentype)
         string += '{:5d}'.format(padding_mask)
         print(string)
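Since truncate_segments now loops instead of recursing and reports whether anything was dropped, callers can record that fact in the sample (the 'truncated' field above). A small illustration with made-up token ids; it assumes the package imports cleanly, i.e. the compiled helpers extension is present:

import random
from megatron.data.dataset_utils import truncate_segments

rng = random.Random(1234)
tokens_a = list(range(10))    # 10 fake token ids
tokens_b = list(range(8))     # 8 fake token ids

# Within budget: nothing is removed and False is returned.
assert truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b), 20, rng) is False

# Over budget: ids are popped from the front or back of the longer segment
# until the pair fits, and True is returned.
assert truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b), 12, rng) is True
assert len(tokens_a) + len(tokens_b) == 12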
megatron/data/helpers.cpp (+188 −175)

@@ -3,6 +3,7 @@
 #include <iostream>
 #include <limits>
 #include <math.h>
+#include <stdexcept>
 #include <pybind11/pybind11.h>
 #include <pybind11/numpy.h>

@@ -20,11 +21,11 @@ inline uint32_t get_sample_len(const int short_seq_ratio,
     return max_length;
 }

-py::array_t<uint32_t> build_mapping(const py::array_t<uint32_t>& docs_,
+template<typename DocIdx>
+py::array build_mapping_impl(const py::array_t<uint32_t>& docs_,
                                     const py::array_t<uint16_t>& sizes_,
                                     const int num_epochs,
-                                    const int max_num_samples,
+                                    const uint64_t max_num_samples,
                                     const int max_seq_length,
                                     const double short_seq_prob,
                                     const int seed) {

@@ -33,7 +34,7 @@
          " documents with " << sizes_.shape(0) << " sentences ..." << endl;

     // For efficiency, convert probability to ratio.
-    const int short_seq_ratio = int(round(1.0 / short_seq_prob));
+    const auto short_seq_ratio = static_cast<int>(round(1.0 / short_seq_prob));

     // Remove bound checks.
     auto docs = docs_.unchecked<1>();

@@ -47,8 +48,8 @@
     }

     // Mapping and it's length (1D).
-    int num_samples = -1;
-    uint32_t* maps = NULL;
+    int64_t num_samples = -1;
+    DocIdx* maps = NULL;

     // Perform two iterations, in the first iteration get the size
     // and allocate memory and in the second iteration populate the map.

@@ -59,9 +60,7 @@
         srand(seed);

         // Set the flag on second iteration.
-        if (iteration == 1) {
-            second = true;
-        }
+        second = iteration == 1;

         // Counters:
         uint32_t empty_docs = 0;

@@ -72,7 +71,7 @@
         // For each epoch:
         for (int epoch=0; epoch < num_epochs; ++epoch) {
-            if (map_index >= max_num_samples) {
+            if (map_index >= max_num_samples && !second) {
                 cout << " > reached " << max_num_samples << " samples after " << epoch << " epochs ..." << endl;
                 break;

@@ -81,14 +80,14 @@
             for (int doc=0; doc < (docs.shape(0) - 1); ++doc) {

                 // Document sentences are in [sent_index_first, sent_index_last).
-                const uint32_t sent_index_first = docs[doc];
-                const uint32_t sent_index_last = docs[doc + 1];
+                const auto sent_index_first = docs[doc];
+                const auto sent_index_last = docs[doc + 1];

                 // At the begining of the document previous index is the start index.
-                uint32_t prev_start_index = sent_index_first;
+                auto prev_start_index = sent_index_first;

                 // Remaining documents.
-                uint32_t num_remain_sent = sent_index_last - sent_index_first;
+                auto num_remain_sent = sent_index_last - sent_index_first;

                 // Some bookkeeping
                 if ((epoch == 0) && (!second)) {

@@ -107,12 +106,12 @@
                 if (num_remain_sent > 1) {

                     // Set values.
-                    uint32_t size = 0;
-                    uint32_t num_sent = 0;
-                    uint32_t seq_len = get_sample_len(short_seq_ratio, max_seq_length);
+                    auto size = uint32_t{0};
+                    auto num_sent = uint32_t{0};
+                    auto seq_len = get_sample_len(short_seq_ratio, max_seq_length);

                     // Loop through sentences.
-                    for (uint32_t sent_index=sent_index_first;
+                    for (auto sent_index=sent_index_first;
                          sent_index < sent_index_last; ++sent_index) {

                         // Add the size and number of sentences.

@@ -129,13 +128,19 @@
                             // Populate the map.
                             if (second) {
-                                const uint64_t map_index_0 = 3 * map_index;
+                                const auto map_index_0 = 3 * map_index;
                                 maps[map_index_0] = prev_start_index;
                                 maps[map_index_0 + 1] = sent_index + 1;
                                 maps[map_index_0 + 2] = seq_len;
                             }

                             // Update indices / counters.
+                            // check for overflow
+                            if (map_index == std::numeric_limits<DocIdx>::max()) {
+                                cout << "number of samples exceeded maximum allowed by type: " << std::numeric_limits<DocIdx>::max() << endl;
+                                throw std::overflow_error("Number of samples");
+                            }
                             map_index += 1;
                             prev_start_index = sent_index + 1;
                             seq_len = get_sample_len(short_seq_ratio, max_seq_length);

@@ -148,29 +153,24 @@
         } // for (int doc=0; doc < num_docs; ++doc) {
         } // for (int epoch=0; epoch < num_epochs; ++epoch) {

-        // For now only support mappings up to MAX_INT.
-        if (map_index > std::numeric_limits<int>::max()) {
-            cout << "number of samples ("<< map_index <<") exceeded MAX_INT" << endl;
-            throw(-1);
-        } else if (!second) {
+        if (!second) {
             cout << " number of samples: " << map_index << endl;
             cout << " number of empty documents: " << empty_docs << endl;
             cout << " number of documents with one sentence: " << one_sent_docs << endl;
-            maps = new uint32_t[3*map_index];
-            num_samples = int(map_index);
+            maps = new DocIdx[3*map_index];
+            num_samples = map_index;
         }

     } // for (int iteration=0; iteration < 2; ++iteration) {

     // Shuffle.
-    for (int i=(num_samples - 1); i > 0; --i) {
-        const int j = rand() % (i + 1);
-        uint64_t i0 = 3 * i;
-        uint64_t j0 = 3 * j;
+    for (auto i=(num_samples - 1); i > 0; --i) {
+        const auto j = rand() % (i + 1);
+        const auto i0 = 3 * i;
+        const auto j0 = 3 * j;
         // Swap values.
         swap(maps[i0], maps[j0]);
         swap(maps[i0 + 1], maps[j0 + 1]);

@@ -181,22 +181,35 @@
     // Method to deallocate memory.
     py::capsule free_when_done(maps, [](void *mem_) {
-        uint32_t *mem = reinterpret_cast<uint32_t *>(mem_);
+        DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
         cout << "freeing memory for the dataset mapping" << endl;
         delete[] mem;
     });

     // Return the numpy array.
-    return py::array_t<uint32_t>({num_samples, 3}, // shape
+    return py::array(std::vector<int64_t>{num_samples, 3}, // shape
                      {3*4, 4}, // C-style contiguous strides
                      maps, // the data pointer
                      free_when_done); // numpy array references
 }

+py::array build_mapping(const py::array& docs_,
+                        const py::array& sizes_,
+                        const int num_epochs,
+                        const uint64_t max_num_samples,
+                        const int max_seq_length,
+                        const double short_seq_prob,
+                        const int seed) {
+    if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
+        return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs, max_num_samples,
+                                            max_seq_length, short_seq_prob, seed);
+    } else {
+        return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs, max_num_samples,
+                                            max_seq_length, short_seq_prob, seed);
+    }
+}

 PYBIND11_MODULE(helpers, m) {
     m.def("build_mapping", &build_mapping);
 }
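From Python the extension is called exactly as in AlbertDataset above. A rough sketch with toy arrays; the uint32/uint16 dtypes mirror the pybind11 signature and are an assumption about what the indexed dataset stores, since pybind11 converts compatible inputs:

import numpy as np
from megatron.data import helpers

# Two tiny "documents": doc_idx holds sentence offsets into sizes,
# sizes holds per-sentence token counts.
doc_idx = np.array([0, 3, 5], dtype=np.uint32)            # doc 0 -> sentences [0, 3), doc 1 -> [3, 5)
sizes = np.array([60, 45, 30, 80, 70], dtype=np.uint16)

mapping = helpers.build_mapping(
    doc_idx, sizes,
    4,       # num_epochs
    100,     # max_num_samples
    509,     # max_seq_length, i.e. 512 - 3 for [CLS] A [SEP] B [SEP]
    0.1,     # short_seq_prob
    1234)    # seed

# Each row is (first sentence index, last sentence index + 1, target sequence length);
# the dtype is uint32 here and uint64 once a corpus has more than 2^32 - 1 sentences.
print(mapping.shape, mapping.dtype)

The rows are shuffled in place with rand() seeded from `seed`, so the mapping order is reproducible for a fixed seed.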