Commit cf2f4d9d authored by Jared Casper

Merge branch 'code_reuse' into 'master'

refactored for code reuse

See merge request ADLR/megatron-lm!11
parents beb3e0d3 cbd8c054
megatron/training.py +84 −63
@@ -71,28 +71,17 @@ def run(top_level_message, train_val_test_data_provider,
             function add `batch generator` to the timers class.
     """
 
+    # Timer.
+    timers = Timers()
+
     # Arguments.
     args = get_args()
 
-    # Timer.
-    timers = Timers()
-
     # Tensorboard writer
     writer = get_tensorboard_writer(args)
 
-    # Pytorch distributed.
-    initialize_distributed(args)
-    if torch.distributed.get_rank() == 0:
-        print(top_level_message, flush=True)
-        print_args(args, writer)
-
-    # Autoresume.
-    torch.distributed.barrier()
-    if args.adlr_autoresume:
-        enable_adlr_autoresume(args)
-
-    # Random seeds for reproducibility.
-    set_random_seed(args.seed)
+    # Initialize.
+    initialize_megatron(top_level_message, args, writer)
 
     # Data stuff.
     train_data, val_data, test_data = train_val_test_data_provider(args)
@@ -135,6 +124,24 @@
                                    args, None, 0, timers, True)
 
 
+def initialize_megatron(message, args, writer):
+    """Initialize distributed, random seed, and autoresume."""
+
+    # Pytorch distributed.
+    initialize_distributed(args)
+    if torch.distributed.get_rank() == 0:
+        print(message, flush=True)
+        print_args(args, writer)
+
+    # Autoresume.
+    torch.distributed.barrier()
+    if args.adlr_autoresume:
+        enable_adlr_autoresume(args)
+
+    # Random seeds for reproducibility.
+    set_random_seed(args.seed)
+
+
 def get_model(model_provider_func, args):
     """Build the model."""
 
@@ -301,53 +308,31 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
     return loss_reduced, skipped_iter
 
 
-def train(forward_step_func, model, optimizer, lr_scheduler,
-          train_data_iterator, val_data_iterator, timers, args, writer):
-    """Train the model function."""
-
-    # Turn on training mode which enables dropout.
-    model.train()
-
-    # Tracking loss.
-    total_loss_dict = {}
-
-    # Iterations.
-    iteration = args.iteration
-    skipped_iters = 0
-
-    timers('interval time').start()
-    report_memory_flag = True
-    while iteration < args.train_iters:
-
-        loss_dict, skipped_iter = train_step(forward_step_func,
-                                             train_data_iterator,
-                                             model,
-                                             optimizer,
-                                             lr_scheduler,
-                                             args, timers)
-        skipped_iters += skipped_iter
-        iteration += 1
+def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
+                 loss_scale, report_memory_flag, writer, args, timers):
 
     # Update losses.
     for key in loss_dict:
         total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
 
     # Logging.
-        if args.DDP_impl == 'torch':
-            timers_to_log = ['forward', 'backward', 'optimizer',
-                             'batch generator']
-        else:
-            timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
-                             'batch generator']
-
-        learning_rate = optimizer.param_groups[0]['lr']
-
+    timers_to_log = []
+    def add_to_logging(name):
+        if name in timers.timers:
+            timers_to_log.append(name)
+    add_to_logging('forward')
+    add_to_logging('backward')
+    add_to_logging('allreduce')
+    add_to_logging('optimizer')
+    add_to_logging('batch generator')
 
     # Tensorboard values.
     if writer and torch.distributed.get_rank() == 0:
         writer.add_scalar('learning_rate', learning_rate, iteration)
-            for key in total_loss_dict:
-                writer.add_scalar(key, total_loss_dict[key], iteration)
+        for key in loss_dict:
+            writer.add_scalar(key, loss_dict[key], iteration)
         if args.fp16:
-                writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
+            writer.add_scalar('loss_scale', loss_scale, iteration)
         normalizer = iteration % args.log_interval
         if normalizer == 0:
             normalizer = args.log_interval
@@ -369,14 +354,50 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             log_string += ' {}: {:.6E} |'.format(key, avg)
             total_loss_dict[key] = 0.0
         if args.fp16:
-                log_string += ' loss scale: {:.1f} |'.format(
-                    optimizer.loss_scale)
+            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
         print_rank_0(log_string)
         if report_memory_flag:
             report_memory('after {} iterations'.format(iteration))
             report_memory_flag = False
         timers.log(timers_to_log, normalizer=args.log_interval)
 
+    return report_memory_flag
+
+
+def train(forward_step_func, model, optimizer, lr_scheduler,
+          train_data_iterator, val_data_iterator, timers, args, writer):
+    """Train the model function."""
+
+    # Turn on training mode which enables dropout.
+    model.train()
+
+    # Tracking loss.
+    total_loss_dict = {}
+
+    # Iterations.
+    iteration = args.iteration
+    skipped_iters = 0
+
+    timers('interval time').start()
+    report_memory_flag = True
+    while iteration < args.train_iters:
+
+        loss_dict, skipped_iter = train_step(forward_step_func,
+                                             train_data_iterator,
+                                             model,
+                                             optimizer,
+                                             lr_scheduler,
+                                             args, timers)
+        skipped_iters += skipped_iter
+        iteration += 1
+
+        # Logging.
+        report_memory_flag = training_log(loss_dict, total_loss_dict,
+                                          optimizer.param_groups[0]['lr'],
+                                          iteration, optimizer.loss_scale,
+                                          report_memory_flag, writer, args,
+                                          timers)
+
         # Autoresume
         if (iteration % args.adlr_autoresume_interval == 0) and \
            args.adlr_autoresume:
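
Reviewer note: the add_to_logging pattern above replaces the old DDP_impl if/else by logging whichever timers actually exist. A self-contained sketch of that behavior; the Timers class here is a minimal stand-in for megatron's, kept only so the snippet runs on its own:

class Timers:
    # Minimal stand-in: the real megatron Timers also times and logs.
    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        return self.timers.setdefault(name, object())

timers = Timers()
# With DDP_impl == 'torch' the local DDP wrapper never starts an
# 'allreduce' timer, so it never appears in timers.timers.
for name in ('forward', 'backward', 'optimizer', 'batch generator'):
    timers(name)

timers_to_log = []
def add_to_logging(name):
    if name in timers.timers:
        timers_to_log.append(name)

for name in ('forward', 'backward', 'allreduce', 'optimizer',
             'batch generator'):
    add_to_logging(name)

print(timers_to_log)
# ['forward', 'backward', 'optimizer', 'batch generator']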
megatron/utils.py +11 −1
@@ -31,9 +31,19 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 
 
+def reduce_losses(losses):
+    reduced_losses = torch.cat(
+        [loss.clone().detach().view(1) for loss in losses])
+    torch.distributed.all_reduce(reduced_losses)
+    reduced_losses = reduced_losses / torch.distributed.get_world_size()
+
+    return reduced_losses
+
+
 def get_tensorboard_writer(args):
     writer = None
-    if args.tensorboard_dir and args.rank == 0:
+    if hasattr(args, 'tensorboard_dir') and \
+       args.tensorboard_dir and args.rank == 0:
         try:
             from torch.utils.tensorboard import SummaryWriter
             writer = SummaryWriter(log_dir=args.tensorboard_dir)
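
Reviewer note: a runnable sketch of the new reduce_losses helper's semantics. The single-process gloo group below is an assumption made only so the snippet runs locally; in a real job get_world_size() is the size of the process group, so each returned entry is that loss averaged over all ranks. The hasattr guard added to get_tensorboard_writer is independent of this: it lets argument namespaces that lack a tensorboard_dir attribute skip writer creation instead of raising AttributeError.

import os
import torch
import torch.distributed

def reduce_losses(losses):
    # Same body as the helper added above.
    reduced_losses = torch.cat(
        [loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(reduced_losses)
    reduced_losses = reduced_losses / torch.distributed.get_world_size()
    return reduced_losses

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
torch.distributed.init_process_group('gloo', rank=0, world_size=1)

lm_loss = torch.tensor(2.0)
nsp_loss = torch.tensor(0.5)
# With world_size == 1 the all-reduce is an identity; with N ranks each
# entry would be the mean of that loss across the group.
print(reduce_losses([lm_loss, nsp_loss]))  # tensor([2.0000, 0.5000])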
pretrain_bert.py +3 −7
@@ -22,6 +22,7 @@ from configure_data import configure_data
 from megatron import mpu
 from megatron.model import BertModel
 from megatron.utils import print_rank_0
+from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
 
@@ -99,14 +100,9 @@ def forward_step(data_iterator, model, args, timers):
 
     loss = lm_loss + nsp_loss
 
-    reduced_losses = torch.cat((lm_loss.clone().detach().view(1),
-                                nsp_loss.clone().detach().view(1)))
-    torch.distributed.all_reduce(reduced_losses)
-    reduced_losses = reduced_losses / torch.distributed.get_world_size()
-    lm_loss_reduced = reduced_losses[0]
-    nsp_loss_reduced = reduced_losses[1]
+    reduced_losses = reduce_losses([lm_loss, nsp_loss])
 
-    return loss, {'lm loss': lm_loss_reduced, 'nsp loss': nsp_loss_reduced}
+    return loss, {'lm loss': reduced_losses[0], 'nsp loss': reduced_losses[1]}
 
 
 def get_train_val_test_data(args):
pretrain_gpt2.py +3 −4
@@ -22,6 +22,7 @@ from gpt2_data_loader import make_gpt2_dataloaders
 from megatron import mpu
 from megatron.model import GPT2Model
 from megatron.utils import print_rank_0
+from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
 
@@ -155,11 +156,9 @@ def forward_step(data_iterator, model, args, timers):
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
 
     # Reduce loss for logging.
-    reduced_loss = loss.clone().detach().view(1)
-    torch.distributed.all_reduce(reduced_loss)
-    reduced_loss = reduced_loss / torch.distributed.get_world_size()
+    reduced_loss = reduce_losses([loss])
 
-    return loss, {'lm loss': reduced_loss}
+    return loss, {'lm loss': reduced_loss[0]}
 
 
 def get_train_val_test_data(args):
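
Reviewer note: the two forward_step changes above are the same move, collapsing a hand-rolled all-reduce into a single reduce_losses call. A sketch of the shared shape; forward_step_tail is an illustrative name, not something this merge request adds:

def forward_step_tail(loss_terms, names):
    # Train on the raw local loss; log the cross-rank averages by name.
    loss = sum(loss_terms)
    reduced = reduce_losses(loss_terms)
    return loss, {name: reduced[i] for i, name in enumerate(names)}

# BERT:  forward_step_tail([lm_loss, nsp_loss], ['lm loss', 'nsp loss'])
# GPT-2: forward_step_tail([loss], ['lm loss'])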