megatron/schedules.py  +13 −0

@@ -43,6 +43,12 @@ def get_forward_backward_func():
     return forward_backward_func


 def free_output_tensor(output_tensors):
+    '''Pseudo-free (i.e., set to scalar) the output tensor's '.data' field.
+
+    This method should be called right after the output tensor has been
+    sent to the next pipeline stage. At this point, the output tensor is
+    only useful for its '.grad_fn' field, and not its '.data'.
+    '''
     if output_tensors is None:
         return
     if isinstance(output_tensors, torch.Tensor):

@@ -51,6 +57,13 @@ def free_output_tensor(output_tensors):
         output_tensor.data = torch.FloatTensor([0]).to(output_tensor.data)


 def custom_backward(output, grad_output):
+    '''Directly call C++ autograd engine.
+
+    To make the 'free_output_tensor' (above) optimization work, the C++
+    autograd engine must be called directly, bypassing Pytorch's
+    torch.autograd.backward. Pytorch's 'backward' checks that the output and
+    grad have the same shape, while C++'s 'backward' does not.
+    '''
     assert output.numel() == 1, \
         "output should be pseudo-'freed' in schedule, to optimize memory"
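The new docstring explains the optimization: once an output activation has been sent to the next pipeline stage, only its '.grad_fn' is still needed for the backward pass, so its '.data' buffer can be swapped for a 1-element placeholder to cut activation memory. A minimal sketch of that idea follows; the list-handling branch and loop are assumptions inferred from the plural 'output_tensors' name and the visible assignment line, not code copied from the PR.

import torch

def free_output_tensor(output_tensors):
    '''Pseudo-free (i.e., set to scalar) the output tensor's '.data' field.

    Sketch only: single-tensor vs. list handling is assumed from the
    visible diff lines, not taken verbatim from the PR.
    '''
    if output_tensors is None:
        return
    if isinstance(output_tensors, torch.Tensor):
        output_tensors = [output_tensors]
    for output_tensor in output_tensors:
        # Replace the (possibly huge) activation buffer with a 1-element
        # tensor of matching dtype/device; '.grad_fn' is left intact so the
        # autograd graph can still be walked during the backward pass.
        output_tensor.data = torch.FloatTensor([0]).to(output_tensor.data)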
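After 'free_output_tensor' runs, 'output.data' is a 1-element placeholder while 'grad_output' still has the original activation shape, so torch.autograd.backward's Python-level shape check would reject the pair; 'custom_backward' therefore drives the autograd engine directly. The sketch below shows one way to do that via PyTorch's internal Variable._execution_engine.run_backward, which is what torch.autograd.backward calls under the hood; its keyword names are an internal, version-dependent API, and everything past the visible assert is an illustration rather than the PR's exact implementation.

import torch
from torch.autograd import Variable

def custom_backward(output, grad_output):
    '''Directly call C++ autograd engine.

    Sketch only: the run_backward call mirrors what torch.autograd.backward
    does internally; keyword names may differ across PyTorch versions.
    '''
    assert output.numel() == 1, \
        "output should be pseudo-'freed' in schedule, to optimize memory"
    # Bypass torch.autograd.backward's shape check between 'output' and
    # 'grad_output' by invoking the C++ engine directly (internal API).
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=tuple(),
        allow_unreachable=True,
        accumulate_grad=True,
    )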