Commit 6e21c2fb authored by Yin, Junqi
Browse files

add before_sync print for ring/torus

parent dc292b93
...@@ -98,9 +98,15 @@ class SGD(Optimizer): ...@@ -98,9 +98,15 @@ class SGD(Optimizer):
n_bits = get_n_bits(flatten_grads.buffer) n_bits = get_n_bits(flatten_grads.buffer)
else: else:
with kargs["timer"]("sync.apply_grad", epoch=self.conf.epoch_): with kargs["timer"]("sync.apply_grad", epoch=self.conf.epoch_):
utils.apply_gradient( if self.conf.print_grad:
self.param_groups, self.state, apply_grad_to_model=True utils.apply_gradient(
) self.param_groups, self.state, apply_grad_to_model=True,
dryrun = False, logger=self.conf.logger
self.param_groups, self.state, apply_grad_to_model=True
with kargs["timer"]("sync.get_data", epoch=self.conf.epoch_): with kargs["timer"]("sync.get_data", epoch=self.conf.epoch_):
# first get and flatten all params. # first get and flatten all params.
...@@ -51,12 +51,14 @@ def apply_gradient(param_groups, state, apply_grad_to_model=True, ...@@ -51,12 +51,14 @@ def apply_gradient(param_groups, state, apply_grad_to_model=True,
else: else:
d_p = buf d_p = buf
if apply_grad_to_model: if apply_grad_to_model:
if dryrun: name = group["name"]
if dryrun: # complete graph
pnew = torch.add(, d_p, alpha=-group["lr"]) pnew = torch.add(, d_p, alpha=-group["lr"])
name = group["name"]
logger.log(f"parameter: {name} param_norm: {pnew.norm()}") logger.log(f"parameter: {name} param_norm: {pnew.norm()}")
else: else:, alpha=-group["lr"]), alpha=-group["lr"])
if logger is not None: # ring or torus
logger.log(f"parameter: {name} param_norm: {p.norm()}")
else: else: = d_p = d_p
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment