megatron/model/transformer.py +6 −9

@@ -699,15 +699,12 @@ class ParallelTransformer(MegatronModule):
             # See set_input_tensor()
             hidden_states = self.input_tensor
-            # hidden_states = make_standalone_tensor(hidden_states)
-            # hidden_states = MakeStandaloneTensor.apply(hidden_states)
-            # hidden_states = MakeViewlessTensor.apply(hidden_states)
-            hidden_states = make_viewless_tensor(hidden_states)
-            # hidden_states = hidden_states.clone()
-            # >>>
-            # from lutil import pax
-            # pax(0, {"hidden_states": hidden_states})
-            # <<<
+            # Viewless tensor
+            hidden_states = make_viewless_tensor(
+                hidden_states,
+                requires_grad = True,
+                keep_graph = True,
+            )

         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()
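Reviewer note, not part of the diff: self.input_tensor comes from set_input_tensor() and, under pipeline parallelism, is typically a view of a larger buffer, so holding onto it pins that buffer. The sketch below illustrates the mechanism the patch relies on (the same trick _kernel_make_viewless_tensor in megatron/mpu/random.py uses); keep_graph = True is used here because the activation sits inside the autograd graph, so the conversion goes through the MakeViewlessTensor autograd function rather than the bare kernel.

# Illustrative sketch only (not part of the patch).
import torch

base = torch.empty(1024, 1024)
view = base.view(-1)
print(view._base is base)    # True: holding `view` keeps `base` alive

# The kernel in megatron/mpu/random.py breaks the link by allocating a tiny
# tensor and re-pointing its .data at the view's storage.
out = torch.empty((1,), dtype=view.dtype, device=view.device)
out.data = view.data
print(out._base is None)     # True: same storage, no ._base back-reference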
megatron/mpu/random.py +58 −19

@@ -99,7 +99,7 @@ def gather_split_1d_tensor(tensor):
     return gathered

 # >>>
-# from lutil import pax
+from lutil import pax
 # ****************
 # def make_standalone_tensor(a):
 #     assert a._base is not None
@@ -107,26 +107,66 @@ def gather_split_1d_tensor(tensor):
 #     b.data = a.data
 #     return b
 # class MakeStandaloneTensor(torch.autograd.Function):
-class MakeViewlessTensor_(torch.autograd.Function):
+# class MakeViewlessTensor_(torch.autograd.Function):
+class MakeViewlessTensor(torch.autograd.Function):
+    # @staticmethod
+    # def forward(ctx, inp):
+    #     assert inp._base is not None
+    #     out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
+    #     out.data = inp.data
+    #     # pax(0, {"inp": inp, "out": out})
+    #     return out
     @staticmethod
-    def forward(ctx, inp):
-        assert inp._base is not None
-        out = torch.empty((1,), dtype = inp.dtype, device = inp.device)
-        out.data = inp.data
-        # pax(0, {"inp": inp, "out": out})
-        return out
+    def forward(ctx, inp, requires_grad):
+        return _kernel_make_viewless_tensor(inp, requires_grad)
+    # @staticmethod
+    # def forward(ctx, args):
+    #     return [_kernel_make_viewless_tensor(*args)]
     @staticmethod
     def backward(ctx, grad_output):
         # pax(0, {"grad_output": grad_output})
-        return grad_output
+        # return grad_output
+        return grad_output, None

+def _kernel_make_viewless_tensor(inp, requires_grad):
+    out = torch.empty(
+        (1,),
+        dtype = inp.dtype,
+        device = inp.device,
+        requires_grad = requires_grad,
+    )
+    out.data = inp.data
+    # >>>
+    # pax(0, {"inp": inp, "out": out})
+    # assert out.requires_grad
+    # <<<
+    return out

-def make_viewless_tensor(tensor):
-    if tensor._base is None:
-        return tensor
+# def make_viewless_tensor(tensor):
+#     if tensor._base is None:
+#         return tensor
+#     else:
+#         return MakeViewlessTensor_.apply(tensor)
+def make_viewless_tensor(inp, requires_grad, keep_graph):
+    # return tensor as-is, if not a 'view'
+    if inp._base is None:
+        return inp
+    # create viewless tensor
+    if keep_graph:
+        # return MakeViewlessTensor.apply((inp, requires_grad))[0]
+        return MakeViewlessTensor.apply(inp, requires_grad)
     else:
-        return MakeViewlessTensor_.apply(tensor)
+        return _kernel_make_viewless_tensor(inp, requires_grad)
+    # return MakeViewlessTensor.apply((inp, requires_grad))[0]
+    # return MakeViewlessTensor.apply(inp, requires_grad)
+    # return MakeViewlessTensor.apply(inp)
+    # return MakeViewlessTensor.apply(inp, 7)
+    # return MakeViewlessTensor.apply(inp, 7)[0]

-def assert_viewless_tensor(tensor):
+def assert_viewless_tensor(tensor, extra_msg = None):
     if isinstance(tensor, list):
         [ assert_viewless_tensor(t) for t in tensor ]
         return
@@ -137,13 +177,12 @@ def assert_viewless_tensor(tensor):
     assert tensor._base is None, (
         "Ensure tensor._base is None before setting tensor.data or storing "
         "tensor to memory buffer. Otherwise, a memory leak will occur (and "
-        "likely accumulate over iterations). FYI, tensor._base has shape "
-        "%s, and new_data_tensor has shape %s."
-    ) % (tensor._base.shape, new_data_tensor.shape)
+        "likely accumulate over iterations). %s"
+    ) % extra_msg

 # def set_viewless_tensor_data_attr(tensor, new_data_tensor):
 def safely_set_tensor_data_attr(tensor, new_data_tensor):
-    assert_viewless_tensor(tensor)
+    assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
     tensor.data = new_data_tensor
 # <<<
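A hedged usage sketch of the helpers this file now provides; it assumes they are imported straight from megatron.mpu.random, where this diff defines them, and exercises both keep_graph paths plus the renamed safely_set_tensor_data_attr:

import torch
from megatron.mpu.random import (
    make_viewless_tensor,
    assert_viewless_tensor,
    safely_set_tensor_data_attr,
)

hidden = torch.randn(8, 4, requires_grad=True)
flat = hidden.view(-1)                    # a view: flat._base is hidden

# Inside the autograd graph (e.g. transformer activations): keep the graph,
# so the conversion is recorded and backward() still reaches `hidden`.
viewless = make_viewless_tensor(flat, requires_grad=True, keep_graph=True)
assert_viewless_tensor(viewless)          # passes: ._base is None

# Outside any graph (e.g. freshly gathered p2p buffers): direct kernel path.
recv = torch.empty(32).view(8, 4)
recv = make_viewless_tensor(recv, requires_grad=True, keep_graph=False)

# .data can now be swapped without a lingering ._base keeping old storage alive.
safely_set_tensor_data_attr(recv, torch.zeros(8, 4))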
megatron/p2p_communication.py +6 −2

@@ -145,12 +145,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
-            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev)
+            tensor_recv_prev = make_viewless_tensor(tensor_recv_prev,
+                                                    requires_grad = True,
+                                                    keep_graph = False)
         if recv_next:
             tensor_recv_next = mpu.gather_split_1d_tensor(
                 tensor_recv_next).view(tensor_shape).requires_grad_()
-            tensor_recv_next = make_viewless_tensor(tensor_recv_next)
+            tensor_recv_next = make_viewless_tensor(tensor_recv_next,
+                                                    requires_grad = True,
+                                                    keep_graph = False)

     return tensor_recv_prev, tensor_recv_next
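For context (an inference, not stated in the diff): the received tensors are rebuilt here from a flat, all-gathered 1-D buffer and reshaped with .view(tensor_shape), so they arrive as views of that buffer. keep_graph = False takes the direct kernel path, since a freshly received buffer has no autograd history to preserve. A minimal stand-in sketch:

import torch
from megatron.mpu.random import make_viewless_tensor  # defined in this diff

tensor_shape = (128, 4, 1024)
flat_recv = torch.empty(128 * 4 * 1024)   # stand-in for mpu.gather_split_1d_tensor(...)
tensor_recv = flat_recv.view(tensor_shape).requires_grad_()
print(tensor_recv._base is flat_recv)     # True: the view pins the flat buffer

tensor_recv = make_viewless_tensor(tensor_recv,
                                   requires_grad=True,
                                   keep_graph=False)
print(tensor_recv._base is None)          # True: a later .data swap won't leak the old buffer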
megatron/schedules.py +0 −16

@@ -631,13 +631,6 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         send_forward(output_tensor, send_tensor_shapes, timers=timers)

         if not forward_only:
-            # >>>
-            # if input_tensor[0] is not None:
-            #     from lutil import pax
-            #     pax({
-            #         "input_tensor" : input_tensor,
-            #     })
-            # <<<
             assert_viewless_tensor(input_tensor)
             assert_viewless_tensor(output_tensor)
             input_tensors.append(input_tensor)
@@ -669,15 +662,6 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                       timers=timers)

             # Add input_tensor and output_tensor to end of list.
-            # >>>
-            # assert input_tensor[0]._base is None, \
-            #     "rank %s; uh oh." % torch.distributed.get_rank()
-            # if input_tensor[0] is not None:
-            #     from lutil import pax
-            #     pax(4, {
-            #         "input_tensor[0]" : input_tensor[0],
-            #     })
-            # <<<
             assert_viewless_tensor(input_tensor)
             assert_viewless_tensor(output_tensor)
             input_tensors.append(input_tensor)
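One detail worth calling out (inferred from the surrounding code, not stated in the diff): in this schedule input_tensor and output_tensor are lists of tensors, as the removed debug code's input_tensor[0] indexing suggests, and assert_viewless_tensor in megatron/mpu/random.py recurses into lists, so the schedule can pass them directly. A small hedged sketch:

import torch
from megatron.mpu.random import assert_viewless_tensor, make_viewless_tensor

t0 = make_viewless_tensor(torch.empty(6).view(2, 3),
                          requires_grad=True, keep_graph=False)
t1 = torch.randn(2, 3)               # already viewless: ._base is None

assert_viewless_tensor([t0, t1])     # list form: every element is checked

bad = torch.empty(6).view(2, 3)      # still a view: ._base is set
# assert_viewless_tensor([t0, bad])  # would raise the memory-leak AssertionError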