Commit 09e05c6f authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

moved albert to bert

parent 3e4e1ab2
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
from . import indexed_dataset
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .albert_dataset import AlbertDataset
+3 −3
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""ALBERT Style dataset."""
"""BERT Style dataset."""

import os
import time
@@ -79,7 +79,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            dataset = AlbertDataset(
            dataset = BertDataset(
                name=name,
                indexed_dataset=indexed_dataset,
                tokenizer=tokenizer,
@@ -105,7 +105,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
    return (train_dataset, valid_dataset, test_dataset)


class AlbertDataset(Dataset):
class BertDataset(Dataset):

    def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
+6 −6
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain ALBERT"""
"""Pretrain BERT"""

import torch
import torch.nn.functional as F
@@ -24,7 +24,7 @@ from megatron.utils import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding
from megatron.training import run
from megatron.data.albert_dataset import build_train_valid_test_datasets
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.data_utils.samplers import DistributedBatchSampler


@@ -116,16 +116,16 @@ def get_train_val_test_data(args):
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        print_rank_0('> building train, validation, and test datasets '
                     'for ALBERT ...')
                     'for BERT ...')

        if args.data_loader is None:
            args.data_loader = 'binary'
        if args.data_loader != 'binary':
            print('Unsupported {} data loader for ALBERT.'.format(
            print('Unsupported {} data loader for BERT.'.format(
                args.data_loader))
            exit(1)
        if not args.data_path:
            print('ALBERT only supports a unified dataset specified '
            print('BERT only supports a unified dataset specified '
                  'with --data-path')
            exit(1)

@@ -157,7 +157,7 @@ def get_train_val_test_data(args):
            short_seq_prob=args.short_seq_prob,
            seed=args.seed,
            skip_warmup=args.skip_mmap_warmup)
        print_rank_0("> finished creating ALBERT datasets ...")
        print_rank_0("> finished creating BERT datasets ...")

        def make_data_loader_(dataset):
            if not dataset: