Commit 113c6362 authored by Jared Casper's avatar Jared Casper

Add temporary assert to finetuning until it can be fixed.

parent 42c1cf42
+3 −1
@@ -19,7 +19,7 @@ from functools import partial
 
 import torch
 
-from megatron import get_args
+from megatron import get_args, get_num_microbatches
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
@@ -154,6 +154,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
     args = get_args()
     timers = get_timers()
 
+    assert get_num_microbatches() == 1, "finetuning with gradient accumulation doesn't currently work"
+
     # Turn on training mode which enables dropout.
     for m in model:
         m.train()
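
For context, a minimal sketch of when the new guard would fire, assuming the usual relationship in which the micro-batch count equals the global batch size divided by the micro-batch size times the data-parallel size. The function name and the example batch sizes below are illustrative, not taken from the commit.

# Sketch only (not Megatron-LM's actual internals): derive the number of
# micro-batches per optimizer step, i.e. the gradient accumulation factor.
def num_microbatches(global_batch_size, micro_batch_size, data_parallel_size):
    return global_batch_size // (micro_batch_size * data_parallel_size)

# No gradient accumulation: one micro-batch per step, finetuning proceeds.
assert num_microbatches(global_batch_size=16, micro_batch_size=4, data_parallel_size=4) == 1

# Gradient accumulation (4 micro-batches per step): the assert added in
# _train() would raise before finetuning starts.
n = num_microbatches(global_batch_size=64, micro_batch_size=4, data_parallel_size=4)
try:
    assert n == 1, "finetuning with gradient accumulation doesn't currently work"
except AssertionError as err:
    print(err)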