Loading megatron/mpu/random.py +3 −1 Original line number Diff line number Diff line Loading @@ -234,7 +234,9 @@ class CheckpointFunction(torch.autograd.Function): if isinstance(outputs, torch.Tensor): outputs = (outputs,) torch.autograd.backward(outputs, args) return (None,) + tuple(inp.grad for inp in detached_inputs) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) return (None,) + grads def checkpoint(function, *args): Loading Loading
megatron/mpu/random.py +3 −1 Original line number Diff line number Diff line Loading @@ -234,7 +234,9 @@ class CheckpointFunction(torch.autograd.Function): if isinstance(outputs, torch.Tensor): outputs = (outputs,) torch.autograd.backward(outputs, args) return (None,) + tuple(inp.grad for inp in detached_inputs) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) return (None,) + grads def checkpoint(function, *args): Loading