Commit 6e21c2fb authored by Yin, Junqi

add before_sync print for ring/torus

parent dc292b93
@@ -98,9 +98,15 @@ class SGD(Optimizer):
                 n_bits = get_n_bits(flatten_grads.buffer)
         else:
             with kargs["timer"]("sync.apply_grad", epoch=self.conf.epoch_):
-                utils.apply_gradient(
-                    self.param_groups, self.state, apply_grad_to_model=True
-                )
+                if self.conf.print_grad:
+                    utils.apply_gradient(
+                        self.param_groups, self.state, apply_grad_to_model=True,
+                        dryrun=False, logger=self.conf.logger
+                    )
+                else:
+                    utils.apply_gradient(
+                        self.param_groups, self.state, apply_grad_to_model=True
+                    )
             with kargs["timer"]("sync.get_data", epoch=self.conf.epoch_):
                 # first get and flatten all params.

@@ -51,12 +51,14 @@ def apply_gradient(param_groups, state, apply_grad_to_model=True,
                 else:
                     d_p = buf
             if apply_grad_to_model:
-                if dryrun:
+                name = group["name"]
+                if dryrun:  # complete graph
                     pnew = torch.add(p.data, d_p, alpha=-group["lr"])
-                    name = group["name"]
                     logger.log(f"parameter: {name} param_norm: {pnew.norm()}")
                 else:
                     p.data.add_(d_p, alpha=-group["lr"])
+                    if logger is not None:  # ring or torus
+                        logger.log(f"parameter: {name} param_norm: {p.norm()}")
             else:
                 p.grad.data = d_p
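
A minimal, self-contained sketch of the logging path this commit adds, for readers without the full repository. The dryrun branch (complete graph) only reports the norm the parameter would have after the update, while the else branch applies the gradient in place and then, if a logger is passed, prints the new parameter norm; per the commit title, this is the "before sync" print for ring/torus topologies. SimpleLogger, apply_gradient_sketch, the single parameter group, and the dummy gradient below are illustrative assumptions; momentum and weight-decay handling from the real apply_gradient is omitted.

# Minimal sketch of the behaviour added above. SimpleLogger,
# apply_gradient_sketch, the single parameter group, and the dummy gradient
# are assumptions for illustration; momentum/weight decay are omitted.
import torch


class SimpleLogger:
    def log(self, msg):
        print(msg)


def apply_gradient_sketch(param_groups, apply_grad_to_model=True,
                          dryrun=False, logger=None):
    for group in param_groups:
        name = group["name"]
        for p in group["params"]:
            if p.grad is None:
                continue
            d_p = p.grad.data
            if apply_grad_to_model:
                if dryrun:  # complete graph: log the would-be update only
                    pnew = torch.add(p.data, d_p, alpha=-group["lr"])
                    logger.log(f"parameter: {name} param_norm: {pnew.norm()}")
                else:  # ring or torus: update in place, then log the new norm
                    p.data.add_(d_p, alpha=-group["lr"])
                    if logger is not None:
                        logger.log(f"parameter: {name} param_norm: {p.norm()}")
            else:
                p.grad.data = d_p


# Usage: one parameter group with a dummy gradient, mirroring the
# print_grad path of SGD.step above.
w = torch.ones(4, requires_grad=True)
w.grad = torch.full_like(w, 0.1)
groups = [{"name": "w", "lr": 0.5, "params": [w]}]
apply_gradient_sketch(groups, dryrun=False, logger=SimpleLogger())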