Commit 9436bdc4 authored by Yin, Junqi's avatar Yin, Junqi

Print parameter norms after gradient synchronization (tag log lines by pipeline stage)

parent ded55b5e
......@@ -57,7 +57,7 @@ def train_and_validate(
with timer("backward_pass", epoch=scheduler.epoch_):
loss.backward()
if conf.print_grad:
print_grad_norm(conf, model, scheduler)
print_grad_norm(conf, model, scheduler, tag='backward_pass')
with timer("sync_complete", epoch=scheduler.epoch_):
if not conf.ddp:
......@@ -65,6 +65,9 @@ def train_and_validate(
else:
optimizer.step()
n_bits_to_transmit = np.nan
if conf.print_grad:
print_grad_norm(conf, model, scheduler, tag='sync_complete')
# display the logging info.
display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)
......
......@@ -93,7 +93,7 @@ def train_and_validate(
loss.backward()
print(conf.graph.rank, "finish backward", idx)
if conf.print_grad:
print_grad_norm(conf, model, scheduler)
print_grad_norm(conf, model, scheduler, tag='backward_pass')
with timer("sync_complete", epoch=scheduler.epoch_):
# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
......@@ -103,7 +103,8 @@ def train_and_validate(
else:
optimizer.step()
n_bits_to_transmit = np.nan
if conf.print_grad:
print_grad_norm(conf, model, scheduler, tag='sync_complete')
# display the logging info.
display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit)
......
......@@ -154,10 +154,13 @@ def dispaly_best_test_stat(conf, scheduler):
)
)
def print_grad_norm(conf, model, scheduler):
conf.logger.log(f"epoch: {scheduler.epoch_} step: {scheduler.local_index}")
def print_grad_norm(conf, model, scheduler, tag='backward_pass'):
    """Log the L2 norm of every trainable parameter (and its gradient).

    One header line is logged with the current epoch, step, and ``tag``,
    followed by one line per parameter with ``requires_grad=True``.
    Gradient norms are only meaningful right after ``loss.backward()``,
    so they are included only for ``tag='backward_pass'``; for
    ``tag='sync_complete'`` (called after ``optimizer.step()``) only the
    parameter norm is logged.

    Args:
        conf: run configuration; only ``conf.logger.log`` is used here.
        model: a ``torch.nn.Module`` (anything exposing ``named_parameters()``).
        scheduler: training scheduler exposing ``epoch_`` and ``local_index``.
        tag: pipeline stage label; ``'backward_pass'`` (default) also logs
            gradient norms, any other value logs parameter norms only.
    """
    conf.logger.log(
        f"epoch: {scheduler.epoch_} step: {scheduler.local_index} tag: {tag}"
    )
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        line = f"parameter: {name} param_norm: {param.data.norm().item()}"
        # param.grad is None until the first backward pass touches this
        # parameter; guard so logging never raises AttributeError.
        if tag == 'backward_pass' and param.grad is not None:
            line += f" grad_norm: {param.grad.data.norm().item()}"
        conf.logger.log(line)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment