Commit f86bb671 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Checked that bert, gpt, and albert runs all work

parent d6485684
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -278,7 +278,7 @@ def add_data_args(parser):
                       help='path(s) to the validation data.')
    group.add_argument('--test-data', nargs='*', default=None,
                       help='path(s) to the testing data.')
    group.add_argument('--data-path', type=str, default=None,
    group.add_argument('--data-path', nargs='+', default=None,
                       help='path to combined dataset to split')
    group.add_argument('--split', default='1000,1,1',
                       help='comma-separated list of proportions for training,'
+2 −0
Original line number Diff line number Diff line
@@ -131,6 +131,8 @@ def make_loaders(args):
    if eval_seq_length is not None and eval_seq_length < 0:
        eval_seq_length = eval_seq_length * world_size
    split = get_split(args)
    if args.data_path is not None:
        args.train_data = args.data_path
    data_set_args = {
        'path': args.train_data,
        'seq_length': seq_length,
+1 −1
Original line number Diff line number Diff line
@@ -57,7 +57,7 @@ def make_gpt2_dataloaders(args):
                                           pin_memory=True)

    train = make_data_loader_(args.train_data)
    valid = make_data_loader_(args.val_data)
    valid = make_data_loader_(args.valid_data)
    test = make_data_loader_(args.test_data)

    args.do_train = False
+2 −1
Original line number Diff line number Diff line
@@ -143,9 +143,10 @@ def get_train_val_test_data(args):
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        assert len(args.data_path) == 1
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
            vocab_file=args.vocab,
            data_prefix=args.data_path,
            data_prefix=args.data_path[0],
            data_impl=args.data_impl,
            splits_string=args.split,
            train_valid_test_num_samples=train_val_test_num_samples,
+7 −1
Original line number Diff line number Diff line
@@ -113,9 +113,15 @@ def get_train_val_test_data(args):
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        if args.data_loader == 'numpy':
            assert len(args.train_data) == 1
            args.train_data = args.train_data[0]
            assert len(args.valid_data) == 1
            args.valid_data = args.valid_data[0]
            assert len(args.test_data) == 1
            args.test_data = args.test_data[0]
            (train_data, val_data, test_data), num_tokens, \
                eod_token = make_gpt2_dataloaders(args)
        elif args.data_loader == 'raw' or args.data_loader == 'lazy'
        elif args.data_loader == 'raw' or args.data_loader == 'lazy':
            data_config = configure_data()
            data_config.set_defaults(data_set_type='GPT2', transpose=False)
            (train_data, val_data, test_data), tokenizer = data_config.apply(