Commit cc691cbf authored by Deepak Narayanan

Move unwrap to megatron/utils.py and clean up imports in megatron/schedules.py

parent e3e5ea89
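
In effect, this commit replaces several hand-rolled DDP-unwrapping loops in checkpointing.py and training.py with one shared helper (its full definition appears in the megatron/utils.py hunks at the bottom of this diff). A minimal before/after sketch of the call-site pattern, assuming model is a list of pipeline-stage modules as in the hunks below:

    # Before: each call site repeated the unwrap loop by hand.
    unwrapped_model = []
    for model_module in model:
        if isinstance(model_module, torchDDP):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    model = unwrapped_model

    # After: one shared helper handles single modules, lists of
    # modules, and nested wrapper classes alike.
    model = utils.unwrap_model(model)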
megatron/checkpointing.py  +5 −16
@@ -21,12 +21,12 @@ import sys
import numpy as np

import torch
-from torch.nn.parallel import DistributedDataParallel as torchDDP

from megatron import (get_args,
                      mpu,
                      print_rank_0,
-                      update_num_microbatches)
+                      update_num_microbatches,
+                      utils)

_CHECKPOINT_VERSION = None

@@ -111,12 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
-    unwrapped_model = []
-    for model_module in model:
-        if isinstance(model_module, torchDDP):
-            model_module = model_module.module
-        unwrapped_model.append(model_module)
-    model = unwrapped_model
+    model = utils.unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))
@@ -220,12 +215,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    args = get_args()
    load_dir = getattr(args, load_arg)

-    unwrapped_model = []
-    for model_module in model:
-        if isinstance(model_module, torchDDP):
-            model_module = model_module.module
-        unwrapped_model.append(model_module)
-    model = unwrapped_model
+    model = utils.unwrap_model(model)

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)
@@ -389,8 +379,7 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f

    args = get_args()

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

    load_path = args.load if from_realm_chkpt else args.ict_load

megatron/schedules.py  +22 −22
@@ -16,14 +16,10 @@
import torch

from megatron import get_args
+from megatron import get_num_microbatches
from megatron import get_timers
from megatron import mpu
-from megatron import get_num_microbatches
-from megatron.p2p_communication import recv_forward, recv_backward
-from megatron.p2p_communication import send_forward, send_backward
-from megatron.p2p_communication import send_forward_recv_backward, send_backward_recv_forward
-from megatron.p2p_communication import send_forward_recv_forward, send_backward_recv_backward
-from megatron.p2p_communication import send_forward_backward_recv_forward_backward
+from megatron import p2p_communication


def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
@@ -154,7 +150,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
-    input_tensors[0].append(recv_forward(timers, use_ring_exchange=True))
+    input_tensors[0].append(p2p_communication.recv_forward(timers, use_ring_exchange=True))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)
        next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
@@ -173,13 +169,14 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
-                send_forward_backward_recv_forward_backward(
+                p2p_communication.send_forward_backward_recv_forward_backward(
                        output_tensor, input_tensor_grad,
                        recv_prev=recv_prev, recv_next=recv_next,
                        timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
        else:
-            input_tensor = send_forward_recv_forward(output_tensor, recv_prev, timers)
+            input_tensor = \
+                p2p_communication.send_forward_recv_forward(output_tensor, recv_prev, timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
@@ -238,7 +235,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
-            send_forward_backward_recv_forward_backward(
+            p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    timers=timers)
@@ -253,7 +250,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
-                recv_backward(timers, use_ring_exchange=True))
+                p2p_communication.recv_backward(timers, use_ring_exchange=True))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -264,7 +261,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
-                send_backward_recv_backward(input_tensor_grad, recv_next, timers))
+                p2p_communication.send_backward_recv_backward(
+                    input_tensor_grad, recv_next, timers))

    return losses_reduced

@@ -294,7 +292,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
-        input_tensor = recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        # Barrier before first receive to measure forward stall.
@@ -302,7 +300,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
            timers('forward-pipeline-stall').start()
            torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
            timers('forward-pipeline-stall').stop()
-        send_forward(output_tensor, timers)
+        p2p_communication.send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
@@ -317,7 +315,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
-        input_tensor = recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
@@ -326,9 +324,10 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
-            send_forward(output_tensor, timers)
+            p2p_communication.send_forward(output_tensor, timers)
        else:
-            output_tensor_grad = send_forward_recv_backward(output_tensor, timers)
+            output_tensor_grad = \
+                    p2p_communication.send_forward_recv_backward(output_tensor, timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
@@ -337,7 +336,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,

        if forward_only:
            if not last_iteration:
-                input_tensor = recv_forward(timers)
+                input_tensor = p2p_communication.recv_forward(timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

@@ -347,9 +346,10 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,

            if last_iteration:
                input_tensor = None
-                send_backward(input_tensor_grad, timers)
+                p2p_communication.send_backward(input_tensor_grad, timers)
            else:
-                input_tensor = send_backward_recv_forward(input_tensor_grad, timers)
+                input_tensor = \
+                        p2p_communication.send_backward_recv_forward(input_tensor_grad, timers)

    # Run cooldown backward passes.
    if not forward_only:
@@ -357,12 +357,12 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

-            output_tensor_grad = recv_backward(timers)
+            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

-            send_backward(input_tensor_grad, timers)
+            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced
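
The schedules.py changes above are purely mechanical: the long list of name imports from megatron.p2p_communication is collapsed into a single module import, and every call site is qualified. A small sketch of the pattern, using two of the primitives named in the hunks:

    # Before: one from-import per point-to-point primitive.
    from megatron.p2p_communication import recv_forward, send_forward

    input_tensor = recv_forward(timers)
    send_forward(output_tensor, timers)

    # After: import the module once; each call now names its origin,
    # making pipeline-communication points easy to spot in the schedule.
    from megatron import p2p_communication

    input_tensor = p2p_communication.recv_forward(timers)
    p2p_communication.send_forward(output_tensor, timers)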
megatron/training.py  +5 −5
@@ -46,6 +46,7 @@ from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import forward_backward_no_pipelining
@@ -288,9 +289,8 @@ def setup_model_and_optimizer(model_provider_func):

    model = get_model(model_provider_func)

-    unwrapped_model = model
-    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-        unwrapped_model = unwrapped_model.module
+    unwrapped_model = unwrap_model(model,
+                                   (torchDDP, LocalDDP, FP16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)
@@ -370,8 +370,8 @@ def train_step(forward_step_func, data_iterator,
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
-        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-            unwrapped_model = unwrapped_model.module
+        unwrapped_model = unwrap_model(
+            unwrapped_model, (torchDDP, LocalDDP, FP16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
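
Note that training.py passes an explicit tuple of wrapper classes, and the helper's while loop strips them repeatedly, so doubly wrapped modules (e.g. a DDP wrapper around the FP16 wrapper) are handled too. A self-contained toy sketch of that behavior, with stand-in classes in place of the real torchDDP/LocalDDP/FP16Module wrappers and unwrap_model as defined in the megatron/utils.py hunk below:

    # Stand-ins for illustration only; each wrapper exposes .module,
    # which is the one attribute unwrap_model relies on.
    class DDPLike:
        def __init__(self, module):
            self.module = module

    class FP16Like:
        def __init__(self, module):
            self.module = module

    bare = object()
    wrapped = DDPLike(FP16Like(bare))  # nested wrapping, outermost first

    # The while loop strips DDPLike, then FP16Like, returning the bare module.
    assert unwrap_model(wrapped, (DDPLike, FP16Like)) is bare
    assert unwrap_model([wrapped], (DDPLike, FP16Like)) == [bare]  # list in, list out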
megatron/utils.py  +18 −1
@@ -18,6 +18,7 @@
import sys

import torch
+from torch.nn.parallel import DistributedDataParallel as torchDDP

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
@@ -26,11 +27,25 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
-from megatron.checkpointing import save_checkpoint
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


+def unwrap_model(model, module_instances=(torchDDP)):
+    return_list = True
+    if not isinstance(model, list):
+        model = [model]
+        return_list = False
+    unwrapped_model = []
+    for model_module in model:
+        while isinstance(model_module, module_instances):
+            model_module = model_module.module
+        unwrapped_model.append(model_module)
+    if not return_list:
+        return unwrapped_model[0]
+    return unwrapped_model


def calc_params_l2_norm(model):
    """Calculate l2 norm of parameters """
    # Remove duplicate params.
@@ -106,6 +121,8 @@ def print_params_min_max_norm(optimizer, iteration):
def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, lr_scheduler):
    """Check for autoresume signal and exit if it is received."""
+    from megatron.checkpointing import save_checkpoint
+
    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
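
One subtlety in the utils.py hunks: the module-level import of save_checkpoint is deleted and re-introduced inside check_adlr_autoresume_termination. Presumably this breaks the import cycle the commit would otherwise create, since megatron/checkpointing.py now imports megatron.utils (first hunk of this diff). A minimal sketch of the pattern, under that assumption:

    # megatron/utils.py -- a top-level "from megatron.checkpointing
    # import save_checkpoint" would now be circular, because
    # checkpointing itself imports this module. Deferring the import
    # to call time means it runs only once both modules are initialized.
    def check_adlr_autoresume_termination(iteration, model,
                                          optimizer, lr_scheduler):
        from megatron.checkpointing import save_checkpoint  # deferred
        ...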