Commit d6485684 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

fixed gpt-2 dataloder

parent b8e0129f
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -115,7 +115,7 @@ def get_train_val_test_data(args):
        if args.data_loader == 'numpy':
            (train_data, val_data, test_data), num_tokens, \
                eod_token = make_gpt2_dataloaders(args)
        elif args.data_loader == 'raw' or args.data_loader == 'tfrecords'
        elif args.data_loader == 'raw' or args.data_loader == 'lazy'
            data_config = configure_data()
            data_config.set_defaults(data_set_type='GPT2', transpose=False)
            (train_data, val_data, test_data), tokenizer = data_config.apply(
@@ -123,6 +123,9 @@ def get_train_val_test_data(args):
            num_tokens = tokenizer.num_tokens
            eod_token = tokenizer.get_command('eos').Id
            assert eod_token == tokenizer.get_command('pad').Id
        else:
            print("Unsupported data loader for GPT2.")
            exit(1)
        # pad.
        num_tokens = vocab_size_with_padding(num_tokens, args)
        print_rank_0('> found end-of-document token: {}'.format(eod_token))