Commit 9993ea25 authored by Jared Casper

Merge branch 'refactor_utils' into 'master'

moved a few common elements between bert and gpt to utils

See merge request ADLR/megatron-lm!4
parents c882ac61 787c1a0b
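
This MR moves initialize_distributed, wrap_model_for_distributed_training, and set_random_seed out of pretrain_bert.py and pretrain_gpt2.py into megatron/utils.py. A hedged sketch of the setup order both scripts follow after the change; the args fields shown are an illustrative subset, not the full argument set:

```python
from argparse import Namespace

# Illustrative subset of the arguments the real scripts build with argparse.
args = Namespace(rank=0, local_rank=0, world_size=1,
                 distributed_backend='nccl', model_parallel_size=1,
                 DDP_impl='local', seed=1234)

# Setup order in pretrain_bert.py / pretrain_gpt2.py after this MR
# (calls shown as comments; they need GPUs and a distributed launcher):
#   initialize_distributed(args)   # torch.distributed + mpu communicators
#   set_random_seed(args.seed)     # python / numpy / torch / mpu seeds
#   model = get_model(args)        # wraps via wrap_model_for_distributed_training
```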
megatron/utils.py  +57 −2
@@ -22,6 +22,7 @@ import numpy as np
 import torch
 
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.fp16 import FP16_Optimizer
 from megatron import mpu
 from megatron import model
@@ -183,13 +184,67 @@ def report_memory(name):
         torch.cuda.max_memory_cached()/ mega_bytes)
     print_rank_0(string)
 
 
-def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
+def initialize_distributed(args):
+    """Initialize torch.distributed."""
+
+    # Manually set the device ids.
+    device = args.rank % torch.cuda.device_count()
+    if args.local_rank is not None:
+        device = args.local_rank
+    torch.cuda.set_device(device)
+    # Call the init process
+    init_method = 'tcp://'
+    master_ip = os.getenv('MASTER_ADDR', 'localhost')
+    master_port = os.getenv('MASTER_PORT', '6000')
+    init_method += master_ip + ':' + master_port
+    torch.distributed.init_process_group(
+        backend=args.distributed_backend,
+        world_size=args.world_size, rank=args.rank,
+        init_method=init_method)
+
+    # Set the model-parallel / data-parallel communicators.
+    mpu.initialize_model_parallel(args.model_parallel_size)
+
+
+def wrap_model_for_distributed_training(model, args):
+    """Wrap model for distributed training."""
+    if args.DDP_impl == 'torch':
+        i = torch.cuda.current_device()
+        args.DDP_type = torchDDP
+        model = args.DDP_type(model, device_ids=[i], output_device=i,
+                              process_group=mpu.get_data_parallel_group())
+        return model
+    elif args.DDP_impl == 'local':
+        args.DDP_type = LocalDDP
+        model = args.DDP_type(model)
+        return model
+    else:
+        print_rank_0('Unknown DDP implementation specified: {}. '
+                     'Exiting.'.format(args.DDP_impl))
+        exit()
+
+
+def set_random_seed(seed):
+    """Set random seed for reproducibility."""
+
+    if seed is not None and seed > 0:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        mpu.model_parallel_cuda_manual_seed(seed)
+
+
+def get_checkpoint_name(checkpoints_path, iteration, release=False,
+                        mp_rank=None):
     if release:
         d = 'release'
     else:
         d = 'iter_{:07d}'.format(iteration)
     return os.path.join(checkpoints_path, d,
-                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank() if mp_rank is None else mp_rank),
+                        'mp_rank_{:02d}'.format(
+                            mpu.get_model_parallel_rank() if mp_rank is None \
+                            else mp_rank),
                         'model_optim_rng.pt')
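
The only behavioral content in the get_checkpoint_name change is line wrapping; the path layout is unchanged. A standalone sketch of that layout, with mpu.get_model_parallel_rank() replaced by a plain mp_rank argument (checkpoint_name is a hypothetical name; the real function needs an initialized mpu):

```python
import os

# Standalone sketch of the path layout get_checkpoint_name produces.
# mp_rank stands in for mpu.get_model_parallel_rank(), which requires an
# initialized model-parallel group, so it is a plain argument here.
def checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=0):
    d = 'release' if release else 'iter_{:07d}'.format(iteration)
    return os.path.join(checkpoints_path, d,
                        'mp_rank_{:02d}'.format(mp_rank),
                        'model_optim_rng.pt')

print(checkpoint_name('/checkpoints/gpt2', 1000))
# /checkpoints/gpt2/iter_0001000/mp_rank_00/model_optim_rng.pt
print(checkpoint_name('/checkpoints/gpt2', 0, release=True, mp_rank=3))
# /checkpoints/gpt2/release/mp_rank_03/model_optim_rng.pt
```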


pretrain_bert.py  +5 −45
@@ -30,7 +30,6 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import BertModel
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model import gpt2_get_params_for_weight_decay_optimization
-from megatron.model import DistributedDataParallel as LocalDDP
 from megatron import mpu
 from apex.optimizers import FusedAdam as Adam
 from megatron.utils import Timers
@@ -42,6 +41,10 @@ from megatron.utils import print_params_min_max_norm
 from megatron.utils import print_rank_0
 from megatron.utils import enable_adlr_autoresume
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import initialize_distributed
+from megatron.utils import set_random_seed
+from megatron.utils import wrap_model_for_distributed_training
 
 
 def get_model(args):
     """Build the model."""
@@ -72,18 +75,7 @@ def get_model(args):
                     _module.float()
 
     # Wrap model for distributed training.
-    if args.DDP_impl == 'torch':
-        i = torch.cuda.current_device()
-        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
-        model = args.DDP_type(model, device_ids=[i], output_device=i,
-                              process_group=mpu.get_data_parallel_group())
-    elif args.DDP_impl == 'local':
-        args.DDP_type = LocalDDP
-        model = args.DDP_type(model)
-    else:
-        print_rank_0('Unknown DDP implementation specified: {}. '
-                     'Exiting.'.format(args.DDP_impl))
-        exit()
+    model = wrap_model_for_distributed_training(model, args)
 
     return model

@@ -474,38 +466,6 @@ def evaluate_and_print_results(prefix, data_iterator, model,
     return val_loss
 
 
-def initialize_distributed(args):
-    """Initialize torch.distributed."""
-
-    # Manually set the device ids.
-    device = args.rank % torch.cuda.device_count()
-    if args.local_rank is not None:
-        device = args.local_rank
-    torch.cuda.set_device(device)
-    # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
-
-    # Set the model-parallel / data-parallel communicators.
-    mpu.initialize_model_parallel(args.model_parallel_size)
-
-
-def set_random_seed(seed):
-    """Set random seed for reproducibility."""
-
-    if seed is not None and seed > 0:
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        mpu.model_parallel_cuda_manual_seed(seed)
-
-
 def get_train_val_test_data(args):
     """Load the data on rank zero and broadcast number of tokens to all GPUs."""

pretrain_gpt2.py  +4 −45
@@ -29,7 +29,6 @@ from megatron.fp16 import FP16_Optimizer
 from megatron.learning_rates import AnnealingLR
 from megatron.model import GPT2Model
 from megatron.model import gpt2_get_params_for_weight_decay_optimization
-from megatron.model import DistributedDataParallel as LocalDDP
 from megatron import mpu
 from apex.optimizers import FusedAdam as Adam
 from megatron.utils import Timers
@@ -41,6 +40,9 @@ from megatron.utils import print_params_min_max_norm
 from megatron.utils import print_rank_0
 from megatron.utils import enable_adlr_autoresume
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import initialize_distributed
+from megatron.utils import set_random_seed
+from megatron.utils import wrap_model_for_distributed_training
 
 from gpt2_data_loader import make_gpt2_dataloaders

@@ -73,18 +75,7 @@ def get_model(args):
         model = FP16_Module(model)
 
     # Wrap model for distributed training.
-    if args.DDP_impl == 'torch':
-        i = torch.cuda.current_device()
-        args.DDP_type = torch.nn.parallel.distributed.DistributedDataParallel
-        model = args.DDP_type(model, device_ids=[i], output_device=i,
-                              process_group=mpu.get_data_parallel_group())
-    elif args.DDP_impl == 'local':
-        args.DDP_type = LocalDDP
-        model = args.DDP_type(model)
-    else:
-        print_rank_0('Unknown DDP implementation specified: {}. '
-                     'Exiting.'.format(args.DDP_impl))
-        exit()
+    model = wrap_model_for_distributed_training(model, args)
 
     return model

@@ -500,38 +491,6 @@ def evaluate_and_print_results(prefix, data_iterator, model,
     return lm_loss
 
 
-def initialize_distributed(args):
-    """Initialize torch.distributed."""
-
-    # Manually set the device ids.
-    device = args.rank % torch.cuda.device_count()
-    if args.local_rank is not None:
-        device = args.local_rank
-    torch.cuda.set_device(device)
-    # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
-
-    # Set the model-parallel / data-parallel communicators.
-    mpu.initialize_model_parallel(args.model_parallel_size)
-
-
-def set_random_seed(seed):
-    """Set random seed for reproducibility."""
-
-    if seed is not None and seed > 0:
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        mpu.model_parallel_cuda_manual_seed(seed)
-
-
 def get_train_val_test_data(args):
     """Load the data on rank zero and broadcast number of tokens to all GPUs."""