Commit c20f4d48 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'checkpointing-fix' into 'master'

small fix for CheckpointFunction's backward() method when some args may be NoneType

See merge request ADLR/megatron-lm!92
parents 46a536cc 4ee0537a
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -234,7 +234,9 @@ class CheckpointFunction(torch.autograd.Function):
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        return (None,) + tuple(inp.grad for inp in detached_inputs)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None,) + grads


def checkpoint(function, *args):