Commit eb74fa34 authored by Mohammad's avatar Mohammad
Browse files

fixed a bug related to last index in the shuffled_idx

parent 5f174c07
Loading
Loading
Loading
Loading
+6 −90
Original line number Diff line number Diff line
@@ -104,7 +104,9 @@ class GPT2Dataset(torch.utils.data.Dataset):


    def __len__(self):
        return self.sample_idx.shape[0]
        # -1 is due to data structure used to retieve the index:
        #    sample i --> [sample_idx[i], sample_idx[i+1])
        return self.sample_idx.shape[0] - 1


    def __getitem__(self, idx):
@@ -185,7 +187,9 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
                         '(seconds): {:4f}'.format(time.time() - start_time))
            # shuffle-idx.
            start_time = time.time()
            shuffle_idx = _build_shuffle_idx(sample_idx.shape[0], np_rng)
            # -1 is due to data structure used to retieve the index:
            #    sample i --> [sample_idx[i], sample_idx[i+1])
            shuffle_idx = _build_shuffle_idx(sample_idx.shape[0]-1, np_rng)
            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
            print_rank_0(' > elasped time to build and save shuffle-idx mapping'
                         ' (seconds): {:4f}'.format(time.time() - start_time))
@@ -306,91 +310,3 @@ def _build_shuffle_idx(size, np_rng):
    shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
    np_rng.shuffle(shuffle_idx)
    return shuffle_idx


'''

class IndexedDataset:

    def __init__(self, num_docs, min_doc_length, max_doc_length, seq_length):

        self.seq_length = seq_length
        assert min_doc_length > 0

        self.tokens = []
        self.sizes = np.zeros(num_docs, dtype=np.int32)
        for i in range(num_docs):
            size = np.random.randint(low=min_doc_length, high=max_doc_length,
                                     size=1, dtype=np.uint32)[0]
            tokens_ = np.random.randint(low=1, high=60000,
                                        size=size, dtype=np.uint32)
            tokens_[-1] = 0
            self.sizes[i] = size
            self.tokens.append(tokens_)

        self.tokens_flat = None

    def get(self, doc_idx, offset=None, length=None):
        if length is None:
            if offset is None:
                return self.tokens[doc_idx]
            else:
                return self.tokens[doc_idx][offset:]
        if offset is None:
            return self.tokens[doc_idx][0:length]
        return self.tokens[doc_idx][offset:(offset+length)]

    def get_sample(self, index):
        start = index * self.seq_length
        end = start + self.seq_length + 1
        return self.tokens_flat[start:end]

    def build_tokens_flat(self, doc_idx):
        self.tokens_flat = np.concatenate([self.tokens[i] for i in doc_idx])


def test(seed, data_prefix, seq_length, num_samples,
         num_docs, min_doc_length, max_doc_length):

    print('testing for seed: {}, seq-length: {}, num-samples: {}, '
          'num-docs: {}, min-doc-length: {}, max-doc-length: {}'.format(
              seed, seq_length, num_samples,
              num_docs, min_doc_length, max_doc_length))
    np.random.seed(seed)

    indexed_dataset = IndexedDataset(num_docs, min_doc_length,
                                     max_doc_length, seq_length)
    indices = np.random.randint(indexed_dataset.sizes.shape[0]-2, size=2)
    documents = np.arange(np.min(indices), np.max(indices)+1)
    dataset = GPT2Dataset('gpt2', data_prefix, documents, indexed_dataset,
                          num_samples, seq_length, seed)

    print(' > number of epochs:', dataset.num_epochs)
    indexed_dataset.build_tokens_flat(dataset.doc_idx)

    for idx in range(num_samples):
        a = dataset[idx]
        b = indexed_dataset.get_sample(idx)
        assert np.sum(a - b) == 0

    print('passed')
    

if __name__ == '__main__':

    print('gpt2 dataset ...')


    import random
    data_prefix = 'junk/'
    for seed in range(1234, 1245):
        random.seed(seed)
        num_docs = random.randint(1, 999)
        min_doc_length = random.randint(1, 99)
        max_doc_length = random.randint(100, 9999)
        num_samples = random.randint(num_docs, 100*num_docs)
        seq_length = random.randint(min_doc_length, max_doc_length)

        test(seed, data_prefix, seq_length, num_samples,
             num_docs, min_doc_length, max_doc_length)
'''