Commit 4ee0537a authored by Devendra Singh Sachan's avatar Devendra Singh Sachan
Browse files

small fix for CheckpointFunction's backward() method when sone args may be NoneType

parent 46a536cc
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -234,7 +234,9 @@ class CheckpointFunction(torch.autograd.Function):
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        return (None,) + tuple(inp.grad for inp in detached_inputs)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None,) + grads


def checkpoint(function, *args):