Commit 7d75b3b5 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'staging-lint' into 'staging'

Lint whole repo

See merge request ADLR/megatron-lm!51
parents c99fa80c b7f1b050
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -357,7 +357,6 @@ def _add_gpt2_args(parser):
    return parser



def add_data_args_(parser):
    """Train/valid/test data arguments."""

@@ -367,6 +366,4 @@ def add_data_args_(parser):
                       choices=['raw', 'lazy', 'tfrecords', 'numpy', 'binary'],
                       help='Which data loader to use. Default varies by model.')


    return parser
+3 −3
Original line number Diff line number Diff line
@@ -67,7 +67,7 @@ def get_checkpoint_name(checkpoints_path, iteration,
        directory = 'iter_{:07d}'.format(iteration)
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}'.format(
                            mpu.get_model_parallel_rank() if mp_rank is None \
                            mpu.get_model_parallel_rank() if mp_rank is None
                            else mp_rank),
                        'model_optim_rng.pt')

@@ -179,7 +179,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
            'megatron.fp16.loss_scaler']
        state_dict = torch.load(checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
    except:
    except BaseException:
        print_rank_0('could not load the checkpoint')
        sys.exit()

+0 −2
Original line number Diff line number Diff line
from . import indexed_dataset

+4 −6
Original line number Diff line number Diff line
@@ -47,6 +47,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '
@@ -113,7 +114,6 @@ class BertDataset(Dataset):
        # Dataset.
        self.indexed_dataset = indexed_dataset


        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
                                                    data_prefix,
@@ -133,11 +133,9 @@ class BertDataset(Dataset):
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad


    def __len__(self):
        return self.samples_mapping.shape[0]


    def __getitem__(self, idx):

        start_index, end_index, seq_length = self.samples_mapping[idx]
+8 −10
Original line number Diff line number Diff line
@@ -42,6 +42,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '
@@ -102,13 +103,11 @@ class GPT2Dataset(torch.utils.data.Dataset):
            self.name, data_prefix, documents, self.indexed_dataset.sizes,
            num_samples, seq_length, seed)


    def __len__(self):
        # -1 is due to data structure used to retieve the index:
        #    sample i --> [sample_idx[i], sample_idx[i+1])
        return self.sample_idx.shape[0] - 1


    def __getitem__(self, idx):
        # Get the shuffled index.
        idx = self.shuffle_idx[idx]
@@ -138,7 +137,6 @@ class GPT2Dataset(torch.utils.data.Dataset):
        return {'text': np.array(sample, dtype=np.int64)}



def _build_index_mappings(name, data_prefix, documents, sizes,
                          num_samples, seq_length, seed):
    """Build doc-idx, sample-idx, and shuffle-idx.
Loading