Commit 7eb85c6d authored by Yin, Junqi

print param before sync

parent 9d3ed392
@@ -66,6 +66,12 @@ class SGD(Optimizer):
     def step(self, closure=None, **kargs):
         if self.conf.is_centralized:
+            if self.conf.print_grad:
+                utils.apply_gradient(
+                    self.param_groups, self.state, apply_grad_to_model=True,
+                    dryrun=True, logger=self.conf.logger
+                )
             with kargs["timer"]("sync.get_data", epoch=self.conf.epoch_):
                 # Get data.
                 grads, _ = comm.get_data(
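In context, this first hunk wires the dry run into the centralized `step()` path: when `conf.print_grad` is set, `apply_gradient` runs with `dryrun=True` before any gradients are exchanged, so the logged parameter norms reflect the purely local update each worker would take prior to averaging. A minimal, self-contained sketch of that gating pattern only; `Conf`, `fake_sync`, and `dry_run_log` are illustrative stand-ins, not the repository's code:

from types import SimpleNamespace

def fake_sync():
    # Stand-in for the gradient exchange done under the "sync.get_data" timer.
    print("exchanging gradients ...")

def step(conf, dry_run_log):
    if conf.is_centralized:
        if conf.print_grad:
            # Would call utils.apply_gradient(..., dryrun=True, logger=...)
            # to log norms before the sync; skipped entirely when the flag is off.
            dry_run_log()
        fake_sync()

conf = SimpleNamespace(is_centralized=True, print_grad=True)
step(conf, dry_run_log=lambda: print("tag: before_sync"))

Gating on the config flag means normal training runs pay none of the extra clone-and-log cost introduced in the hunks below.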
@@ -10,7 +10,10 @@ import pcode.utils.communication as comm
 """common utilities"""
-def apply_gradient(param_groups, state, apply_grad_to_model=True):
+def apply_gradient(param_groups, state, apply_grad_to_model=True,
+                   dryrun=False, logger=None):
+    if logger is not None:
+        logger.log(f"tag: before_sync")
     for group in param_groups:
         weight_decay = group["weight_decay"]
         momentum = group["momentum"]
@@ -32,17 +35,28 @@ def apply_gradient(param_groups, state, apply_grad_to_model=True):
             # apply the momentum.
             if momentum != 0:
                 if "momentum_buffer" not in param_state:
-                    buf = param_state["momentum_buffer"] = torch.zeros_like(p.data)
+                    if dryrun:
+                        buf = torch.zeros_like(p.data)
+                    else:
+                        buf = param_state["momentum_buffer"] = torch.zeros_like(p.data)
                     buf.mul_(momentum).add_(d_p)
                 else:
-                    buf = param_state["momentum_buffer"]
+                    if dryrun:
+                        buf = param_state["momentum_buffer"].detach().clone()
+                    else:
+                        buf = param_state["momentum_buffer"]
                     buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                 if nesterov:
                     d_p = d_p.add(momentum, buf)
                 else:
                     d_p = buf
             if apply_grad_to_model:
-                p.data.add_(d_p, alpha=-group["lr"])
+                if dryrun:
+                    pnew = torch.add(p.data, d_p, alpha=-group["lr"])
+                    name = group["name"]
+                    logger.log(f"parameter: {name} param_norm: {pnew.norm()}")
+                else:
+                    p.data.add_(d_p, alpha=-group["lr"])
             else:
                 p.grad.data = d_p
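Taken together, the utils changes give `apply_gradient` a side-effect-free mode: with `dryrun=True` the momentum buffer is either a throwaway zero tensor or a detached clone, and the updated parameter is materialized out of place with `torch.add` instead of the in-place `p.data.add_`, so only the logger observes the result. Below is a self-contained sketch of that dry-run logic with weight decay and Nesterov omitted for brevity; `apply_gradient_dryrun`, `SimpleLogger`, and the toy parameter group are illustrative names, not part of the repository.

import torch


class SimpleLogger:
    """Illustrative stand-in for conf.logger."""

    def log(self, msg):
        print(msg)


def apply_gradient_dryrun(param_groups, state, logger=None):
    """Compute the SGD update on copies only and log the resulting norms."""
    if logger is not None:
        logger.log("tag: before_sync")
    for group in param_groups:
        momentum = group["momentum"]
        dampening = group["dampening"]
        for p in group["params"]:
            if p.grad is None:
                continue
            d_p = p.grad.data
            param_state = state[p]
            if momentum != 0:
                if "momentum_buffer" not in param_state:
                    # No buffer yet: use a throwaway zero tensor, store nothing.
                    buf = torch.zeros_like(p.data)
                    buf.mul_(momentum).add_(d_p)
                else:
                    # Work on a detached clone so the real buffer is untouched.
                    buf = param_state["momentum_buffer"].detach().clone()
                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                d_p = buf
            # Out-of-place add: p.data itself is never modified.
            pnew = torch.add(p.data, d_p, alpha=-group["lr"])
            if logger is not None:
                logger.log(f"parameter: {group['name']} param_norm: {pnew.norm()}")


# Toy usage: one parameter group, one dry-run pass.
w = torch.randn(4, requires_grad=True)
w.grad = torch.randn(4)
groups = [{"params": [w], "lr": 0.1, "momentum": 0.9, "dampening": 0.0, "name": "w"}]
apply_gradient_dryrun(groups, state={w: {}}, logger=SimpleLogger())

Running the toy example should print the "tag: before_sync" marker followed by one "parameter: ... param_norm: ..." line, matching the log format introduced by this commit, while `w` and the (still empty) optimizer state are left unchanged.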