Commit 03d0d93d authored by Mohammad's avatar Mohammad
Browse files

refactored merge_mp_partitions.py

parent 5655f076
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line
@@ -65,7 +65,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={}):
    """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
    args = _parse_args(extra_args_provider=extra_args_provider,
                       defaults=args_defaults)
    _build_tokenizer(args)
    _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
    _set_timers()
@@ -85,6 +85,13 @@ def _build_tokenizer(args):
    global _GLOBAL_TOKENIZER
    _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
    _GLOBAL_TOKENIZER = build_tokenizer(args)
    return _GLOBAL_TOKENIZER


def rebuild_tokenizer(args):
    global _GLOBAL_TOKENIZER
    _GLOBAL_TOKENIZER = None
    return _build_tokenizer(args)


def _set_tensorboard_writer(args):
+5 −0
Original line number Diff line number Diff line
@@ -15,6 +15,11 @@

"""Sample Generate GPT2"""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))

from megatron import get_args
from megatron import get_tokenizer
from megatron import print_rank_0
+64 −21
Original line number Diff line number Diff line
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Merge model parallel partitions."""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))

import torch

from arguments import get_args
from megatron import mpu
from megatron.utils import ensure_directory_exists
from megatron.utils import get_checkpoint_name
from megatron.utils import get_checkpoint_tracker_filename
from megatron.utils import vocab_size_with_padding
from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.global_vars import rebuild_tokenizer
from megatron.global_vars import _parse_args


def split_into_partitions(tensor, num_partitions, partition_dim, stride):
@@ -84,21 +104,26 @@ def merge_partitions(merged, partitions, partition_dim, stride):
    return


def get_model(model_type, args):
def get_model(model_type):

    if model_type == 'BERT':
        from pretrain_albert import model_provider
        args.tokentype_size = 2
    elif  model_type == 'GPT':
        from pretrain_bert import model_provider
    elif model_type == 'GPT2':
        from pretrain_gpt2 import model_provider
    elif model_type == 'RACE':
        from tasks.race.finetune import model_provider
    elif model_type == ['MNLI', 'QQP']:
        num_classes = 2
        if model_type == 'MNLI':
            num_classes = 3
        from megatron.model.classification import Classification
        def model_provider():
            return Classification(num_classes=num_classes, num_tokentypes=2)
    else:
        raise Exception('unrecognized model type: {}'.format(model_type))

    orig_vocab_size = args.vocab_size
    args.vocab_size = vocab_size_with_padding(args.vocab_size, args)
    model = model_provider(args)
    model = model_provider()
    model = model.half()
    args.vocab_size = orig_vocab_size

    return model

@@ -147,17 +172,32 @@ def test_split_merge():
    print('  > max error (should be zero): {}'.format(max_error))


def main(model_type):
def get_mp_merge_args(parser):
    """Provide extra arguments required for merging."""
    group = parser.add_argument_group(title='mp merge')

    group.add_argument('--model-type', type=str, required=True,
                       choices=['BERT', 'GPT2', 'RACE', 'MNLI', 'QQP'],
                       help='Type of the mdoel.')

    return parser


def main():

    # Args
    args = get_args()
    args = _parse_args(extra_args_provider=get_mp_merge_args)
    model_type = args.model_type
    orig_model_parallel_size = args.model_parallel_size
    args.model_parallel_size = 1
    tokenizer = rebuild_tokenizer(args)

    print('\n merging model parallel partitions ...')
    assert args.vocab_size is not None
    print(' > number of partitions: {}'.format(args.model_parallel_size))
    print(' > number of partitions: {}'.format(orig_model_parallel_size))
    print(' > checkpoint path: {}'.format(args.load))
    print(' > model parameters:')
    print('    number of tokens ................ {} '.format(args.vocab_size))
    print('    number of tokens ................ {} '.format(
        tokenizer.vocab_size))
    print('    number of layers ................ {}'.format(args.num_layers))
    print('    hidden sise ..................... {}'.format(args.hidden_size))
    print('    number of attention heads ....... {}'.format(
@@ -169,17 +209,19 @@ def main(model_type):
    print('> building the full model ...')
    mpu.initialize.set_model_parallel_world_size(1)
    mpu.initialize.set_model_parallel_rank(0)
    merged_model = get_model(model_type, args)
    merged_model = get_model(model_type)

    # Build and load partitions.
    partitions = []
    iteration = 0
    args.model_parallel_size = orig_model_parallel_size
    tokenizer = rebuild_tokenizer(args)
    mpu.initialize.set_model_parallel_world_size(args.model_parallel_size)
    for rank in range(args.model_parallel_size):
        mpu.initialize.set_model_parallel_rank(rank)
        checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
        print('> loading {} ...'.format(checkpoint_name))
        model_ = get_model(model_type, args)
        model_ = get_model(model_type)
        sd = torch.load(checkpoint_name, map_location='cpu')
        model_.load_state_dict(sd['model'])
        partitions.append(model_)
@@ -225,6 +267,7 @@ def main(model_type):


    # Save the model.
    args.model_parallel_size = 1
    mpu.initialize.set_model_parallel_rank(0)
    sd = {}
    sd['model'] = merged_model.state_dict_for_save_checkpoint()
@@ -240,4 +283,4 @@ def main(model_type):

if __name__ == '__main__':

    main('BERT')
    main()
+21 −2
Original line number Diff line number Diff line
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Processing data for pretraining."""

import argparse
import json
import multiprocessing
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
import time



import torch
try:
    import nltk
@@ -16,6 +34,7 @@ except ImportError:
from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset


# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):