Commit 91e11344 authored by Yin, Junqi's avatar Yin, Junqi

fixes lstm

parent 3d4e6ac7
......@@ -40,7 +40,7 @@ def train_and_validate(
# init the hidden state.
_hidden = (
model.module.init_hidden(conf.batch_size)
if "DataParallel" == model.__class__.__name__
if "DataParallel" in model.__class__.__name__
else model.init_hidden(conf.batch_size)
)
......@@ -52,7 +52,7 @@ def train_and_validate(
# repackage the hidden.
_hidden = (
model.module.repackage_hidden(_hidden)
if "DataParallel" == model.__class__.__name__
if "DataParallel" in model.__class__.__name__
else model.repackage_hidden(_hidden)
)
......@@ -94,7 +94,11 @@ def train_and_validate(
with timer("sync_complete", epoch=scheduler.epoch_):
# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
torch.nn.utils.clip_grad_norm_(model.parameters(), conf.rnn_clip)
n_bits_to_transmit = optimizer.step(timer=timer)
if not conf.ddp:
n_bits_to_transmit = optimizer.step(timer=timer)
else:
optimizer.step()
n_bits_to_transmit = np.nan
# display the logging info.
display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)
......@@ -145,7 +149,7 @@ def train_and_validate(
conf.logger.log("eval the averaged model on full training data.")
copied_model = copy.deepcopy(
model.module
if "DataParallel" == model.__class__.__name__
if "DataParallel" in model.__class__.__name__
else model
)
optimizer.world_aggregator.agg_model(copied_model, op="avg")
......@@ -232,7 +236,7 @@ def validate(conf, model, optimizer, criterion, scheduler, metrics, data_loader,
batch_size = conf.batch_size
_hidden = (
model.module.init_hidden(batch_size)
if "DataParallel" == model.__class__.__name__
if "DataParallel" in model.__class__.__name__
else model.init_hidden(batch_size)
)
......@@ -243,7 +247,7 @@ def validate(conf, model, optimizer, criterion, scheduler, metrics, data_loader,
# repackage the hidden.
_hidden = (
model.module.repackage_hidden(_hidden)
if "DataParallel" == model.__class__.__name__
if "DataParallel" in model.__class__.__name__
else model.repackage_hidden(_hidden)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment