Loading megatron/model/transformer.py +8 −10 Original line number Diff line number Diff line Loading @@ -27,6 +27,9 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu # >>> from megatron.mpu.random import make_viewless_tensor # <<< """ We use the following notation throughout this file: h: hidden size Loading Loading @@ -696,19 +699,14 @@ class ParallelTransformer(MegatronModule): # See set_input_tensor() hidden_states = self.input_tensor # >>> def make_standalone_tensor(a): assert a._base is not None b = torch.empty((1,), dtype = a.dtype, device = a.device) b.data = a.data return b # <<< # hidden_states = make_standalone_tensor(hidden_states) hidden_states = hidden_states.clone() # hidden_states = MakeStandaloneTensor.apply(hidden_states) # hidden_states = MakeViewlessTensor.apply(hidden_states) hidden_states = make_viewless_tensor(hidden_states) # hidden_states = hidden_states.clone() # >>> # from lutil import pax # pax({"hidden_states": hidden_states}) # pax(0, {"hidden_states": hidden_states}) # <<< if encoder_output is not None: Loading megatron/mpu/random.py +57 −6 Original line number Diff line number Diff line Loading @@ -98,13 +98,54 @@ def gather_split_1d_tensor(tensor): group=get_tensor_model_parallel_group()) return gathered def safely_set_tensor_data_attr(tensor, new_data_tensor): # >>> # from lutil import pax # def make_standalone_tensor(a): # assert a._base is not None # b = torch.empty((1,), dtype = a.dtype, device = a.device) # b.data = a.data # return b # class MakeStandaloneTensor(torch.autograd.Function): class MakeViewlessTensor_(torch.autograd.Function): @staticmethod def forward(ctx, inp): assert inp._base is not None out = torch.empty((1,), dtype = inp.dtype, device = inp.device) out.data = inp.data # pax(0, {"inp": inp, "out": out}) return out @staticmethod def backward(ctx, grad_output): # pax(0, {"grad_output": grad_output}) return grad_output def make_viewless_tensor(tensor): if tensor._base is None: return tensor else: return MakeViewlessTensor_.apply(tensor) def assert_viewless_tensor(tensor): if isinstance(tensor, list): [ assert_viewless_tensor(t) for t in tensor ] return # assert isinstance(tensor, torch.Tensor), \ # "expected Tensor; found %s." % type(tensor).__name__ if not isinstance(tensor, torch.Tensor): return assert tensor._base is None, ( "Ensure tensor._base is None before setting tensor.data. Otherwise, " "a memory leak will occur (and likely accumulate over iterations). " "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." "Ensure tensor._base is None before setting tensor.data or storing " "tensor to memory buffer. Otherwise, a memory leak will occur (and " "likely accumulate over iterations). FYI, tensor._base has shape " "%s, and new_data_tensor has shape %s." ) % (tensor._base.shape, new_data_tensor.shape) # def set_viewless_tensor_data_attr(tensor, new_data_tensor): def safely_set_tensor_data_attr(tensor, new_data_tensor): assert_viewless_tensor(tensor) tensor.data = new_data_tensor # <<< class CudaRNGStatesTracker: """Tracker for the cuda RNG states. Loading Loading @@ -253,11 +294,13 @@ class CheckpointFunction(torch.autograd.Function): # with data_leak_ctx(args[0]): # <<< ctx.input_0_shape = args[0].data.shape # >>> # args[0].data = split_tensor_into_1d_equal_chunks(args[0].data, # new_buffer=True) safely_set_tensor_data_attr( args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)) # <<< # Store everything. ctx.save_for_backward(*args) Loading @@ -271,8 +314,16 @@ class CheckpointFunction(torch.autograd.Function): "please use .backward() if possible") inputs = ctx.saved_tensors if ctx.distribute_checkpointed_activations: inputs[0].data = gather_split_1d_tensor(inputs[0].data) inputs[0].data = inputs[0].data.view(ctx.input_0_shape) # >>> # inputs[0].data = gather_split_1d_tensor(inputs[0].data) # inputs[0].data = inputs[0].data.view(ctx.input_0_shape) safely_set_tensor_data_attr( inputs[0], gather_split_1d_tensor(inputs[0].data)) safely_set_tensor_data_attr( inputs[0], inputs[0].data.view(ctx.input_0_shape)) # <<< # Store the current states. bwd_cpu_rng_state = torch.get_rng_state() Loading megatron/p2p_communication.py +5 −0 Original line number Diff line number Diff line Loading @@ -20,6 +20,9 @@ import torch from megatron import get_args from megatron import mpu # >>> from megatron.mpu.random import make_viewless_tensor # <<< def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, tensor_shape, Loading Loading @@ -142,10 +145,12 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, if recv_prev: tensor_recv_prev = mpu.gather_split_1d_tensor( tensor_recv_prev).view(tensor_shape).requires_grad_() tensor_recv_prev = make_viewless_tensor(tensor_recv_prev) if recv_next: tensor_recv_next = mpu.gather_split_1d_tensor( tensor_recv_next).view(tensor_shape).requires_grad_() tensor_recv_next = make_viewless_tensor(tensor_recv_next) return tensor_recv_prev, tensor_recv_next Loading megatron/schedules.py +36 −0 Original line number Diff line number Diff line Loading @@ -28,6 +28,10 @@ from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import Float16Module from megatron.model import ModelType # >>> from megatron.mpu.random import assert_viewless_tensor # <<< def get_forward_backward_func(): args = get_args() if mpu.get_pipeline_model_parallel_world_size() > 1: Loading Loading @@ -306,6 +310,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat model[model_chunk_id], input_tensor, losses_reduced) output_tensors[model_chunk_id].append(output_tensor) assert_viewless_tensor(output_tensor) # if forward-only, no need to save tensors for a backward pass if forward_only: Loading Loading @@ -339,6 +344,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat mpu.set_virtual_pipeline_model_parallel_rank(0) input_tensors[0].append( p2p_communication.recv_forward(tensor_shape, timers=timers)) assert_viewless_tensor(input_tensors[0][-1]) for k in range(num_warmup_microbatches): output_tensor = forward_step_helper(k) Loading Loading @@ -370,6 +376,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat tensor_shape=tensor_shape, timers=timers) output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) assert_viewless_tensor(output_tensor_grad) else: input_tensor = \ p2p_communication.send_forward_recv_forward( Loading @@ -378,6 +385,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat timers=timers) free_output_tensor(output_tensor, args.deallocate_pipeline_outputs) input_tensors[next_forward_model_chunk_id].append(input_tensor) assert_viewless_tensor(input_tensor) # Run 1F1B in steady state. for k in range(num_microbatches_remaining): Loading Loading @@ -447,15 +455,18 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat # right location. if recv_prev: input_tensors[next_forward_model_chunk_id].append(input_tensor) assert_viewless_tensor(input_tensor) if recv_next: output_tensor_grads[next_backward_model_chunk_id].append( output_tensor_grad) assert_viewless_tensor(output_tensor_grad) # Run cooldown backward passes (flush out pipeline). if not forward_only: if all_warmup_microbatches: output_tensor_grads[num_model_chunks-1].append( p2p_communication.recv_backward(tensor_shape, timers=timers)) assert_viewless_tensor(output_tensor_grads[num_model_chunks-1][-1]) for k in range(num_microbatches_remaining, num_microbatches): input_tensor_grad = backward_step_helper(k) next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False) Loading @@ -470,6 +481,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, timers=timers)) assert_viewless_tensor(output_tensor_grads[next_backward_model_chunk_id][-1]) return losses_reduced Loading Loading @@ -508,6 +520,7 @@ def recv_forward(tensor_shapes, timers): else: input_tensors.append(p2p_communication.recv_forward(tensor_shape, timers=timers)) assert_viewless_tensor(input_tensors[-1]) return input_tensors Loading @@ -519,6 +532,7 @@ def recv_backward(tensor_shapes, timers): else: output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, timers=timers)) assert_viewless_tensor(output_tensor_grads[-1]) return output_tensor_grads Loading Loading @@ -551,6 +565,7 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers): output_tensor_grad = p2p_communication.send_forward_recv_backward( output_tensor, tensor_shape, timers=timers) output_tensor_grads.append(output_tensor_grad) assert_viewless_tensor(output_tensor_grad) return output_tensor_grads Loading @@ -565,6 +580,7 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers): input_tensor = p2p_communication.send_backward_recv_forward( input_tensor_grad, tensor_shape, timers=timers) input_tensors.append(input_tensor) assert_viewless_tensor(input_tensor) return input_tensors Loading Loading @@ -615,6 +631,15 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite send_forward(output_tensor, send_tensor_shapes, timers=timers) if not forward_only: # >>> if input_tensor[0] is not None: from lutil import pax pax({ "input_tensor" : input_tensor, }) # <<< assert_viewless_tensor(input_tensor) assert_viewless_tensor(output_tensor) input_tensors.append(input_tensor) output_tensors.append(output_tensor) free_output_tensor(output_tensor, args.deallocate_pipeline_outputs) Loading Loading @@ -644,6 +669,17 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite timers=timers) # Add input_tensor and output_tensor to end of list. # >>> # assert input_tensor[0]._base is None, \ # "rank %s; uh oh." % torch.distributed.get_rank() # if input_tensor[0] is not None: # from lutil import pax # pax(4, { # "input_tensor[0]" : input_tensor[0], # }) # <<< assert_viewless_tensor(input_tensor) assert_viewless_tensor(output_tensor) input_tensors.append(input_tensor) output_tensors.append(output_tensor) free_output_tensor(output_tensor, args.deallocate_pipeline_outputs) Loading Loading
megatron/model/transformer.py +8 −10 Original line number Diff line number Diff line Loading @@ -27,6 +27,9 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu # >>> from megatron.mpu.random import make_viewless_tensor # <<< """ We use the following notation throughout this file: h: hidden size Loading Loading @@ -696,19 +699,14 @@ class ParallelTransformer(MegatronModule): # See set_input_tensor() hidden_states = self.input_tensor # >>> def make_standalone_tensor(a): assert a._base is not None b = torch.empty((1,), dtype = a.dtype, device = a.device) b.data = a.data return b # <<< # hidden_states = make_standalone_tensor(hidden_states) hidden_states = hidden_states.clone() # hidden_states = MakeStandaloneTensor.apply(hidden_states) # hidden_states = MakeViewlessTensor.apply(hidden_states) hidden_states = make_viewless_tensor(hidden_states) # hidden_states = hidden_states.clone() # >>> # from lutil import pax # pax({"hidden_states": hidden_states}) # pax(0, {"hidden_states": hidden_states}) # <<< if encoder_output is not None: Loading
megatron/mpu/random.py +57 −6 Original line number Diff line number Diff line Loading @@ -98,13 +98,54 @@ def gather_split_1d_tensor(tensor): group=get_tensor_model_parallel_group()) return gathered def safely_set_tensor_data_attr(tensor, new_data_tensor): # >>> # from lutil import pax # def make_standalone_tensor(a): # assert a._base is not None # b = torch.empty((1,), dtype = a.dtype, device = a.device) # b.data = a.data # return b # class MakeStandaloneTensor(torch.autograd.Function): class MakeViewlessTensor_(torch.autograd.Function): @staticmethod def forward(ctx, inp): assert inp._base is not None out = torch.empty((1,), dtype = inp.dtype, device = inp.device) out.data = inp.data # pax(0, {"inp": inp, "out": out}) return out @staticmethod def backward(ctx, grad_output): # pax(0, {"grad_output": grad_output}) return grad_output def make_viewless_tensor(tensor): if tensor._base is None: return tensor else: return MakeViewlessTensor_.apply(tensor) def assert_viewless_tensor(tensor): if isinstance(tensor, list): [ assert_viewless_tensor(t) for t in tensor ] return # assert isinstance(tensor, torch.Tensor), \ # "expected Tensor; found %s." % type(tensor).__name__ if not isinstance(tensor, torch.Tensor): return assert tensor._base is None, ( "Ensure tensor._base is None before setting tensor.data. Otherwise, " "a memory leak will occur (and likely accumulate over iterations). " "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." "Ensure tensor._base is None before setting tensor.data or storing " "tensor to memory buffer. Otherwise, a memory leak will occur (and " "likely accumulate over iterations). FYI, tensor._base has shape " "%s, and new_data_tensor has shape %s." ) % (tensor._base.shape, new_data_tensor.shape) # def set_viewless_tensor_data_attr(tensor, new_data_tensor): def safely_set_tensor_data_attr(tensor, new_data_tensor): assert_viewless_tensor(tensor) tensor.data = new_data_tensor # <<< class CudaRNGStatesTracker: """Tracker for the cuda RNG states. Loading Loading @@ -253,11 +294,13 @@ class CheckpointFunction(torch.autograd.Function): # with data_leak_ctx(args[0]): # <<< ctx.input_0_shape = args[0].data.shape # >>> # args[0].data = split_tensor_into_1d_equal_chunks(args[0].data, # new_buffer=True) safely_set_tensor_data_attr( args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)) # <<< # Store everything. ctx.save_for_backward(*args) Loading @@ -271,8 +314,16 @@ class CheckpointFunction(torch.autograd.Function): "please use .backward() if possible") inputs = ctx.saved_tensors if ctx.distribute_checkpointed_activations: inputs[0].data = gather_split_1d_tensor(inputs[0].data) inputs[0].data = inputs[0].data.view(ctx.input_0_shape) # >>> # inputs[0].data = gather_split_1d_tensor(inputs[0].data) # inputs[0].data = inputs[0].data.view(ctx.input_0_shape) safely_set_tensor_data_attr( inputs[0], gather_split_1d_tensor(inputs[0].data)) safely_set_tensor_data_attr( inputs[0], inputs[0].data.view(ctx.input_0_shape)) # <<< # Store the current states. bwd_cpu_rng_state = torch.get_rng_state() Loading
megatron/p2p_communication.py +5 −0 Original line number Diff line number Diff line Loading @@ -20,6 +20,9 @@ import torch from megatron import get_args from megatron import mpu # >>> from megatron.mpu.random import make_viewless_tensor # <<< def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, tensor_shape, Loading Loading @@ -142,10 +145,12 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, if recv_prev: tensor_recv_prev = mpu.gather_split_1d_tensor( tensor_recv_prev).view(tensor_shape).requires_grad_() tensor_recv_prev = make_viewless_tensor(tensor_recv_prev) if recv_next: tensor_recv_next = mpu.gather_split_1d_tensor( tensor_recv_next).view(tensor_shape).requires_grad_() tensor_recv_next = make_viewless_tensor(tensor_recv_next) return tensor_recv_prev, tensor_recv_next Loading
megatron/schedules.py +36 −0 Original line number Diff line number Diff line Loading @@ -28,6 +28,10 @@ from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import Float16Module from megatron.model import ModelType # >>> from megatron.mpu.random import assert_viewless_tensor # <<< def get_forward_backward_func(): args = get_args() if mpu.get_pipeline_model_parallel_world_size() > 1: Loading Loading @@ -306,6 +310,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat model[model_chunk_id], input_tensor, losses_reduced) output_tensors[model_chunk_id].append(output_tensor) assert_viewless_tensor(output_tensor) # if forward-only, no need to save tensors for a backward pass if forward_only: Loading Loading @@ -339,6 +344,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat mpu.set_virtual_pipeline_model_parallel_rank(0) input_tensors[0].append( p2p_communication.recv_forward(tensor_shape, timers=timers)) assert_viewless_tensor(input_tensors[0][-1]) for k in range(num_warmup_microbatches): output_tensor = forward_step_helper(k) Loading Loading @@ -370,6 +376,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat tensor_shape=tensor_shape, timers=timers) output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) assert_viewless_tensor(output_tensor_grad) else: input_tensor = \ p2p_communication.send_forward_recv_forward( Loading @@ -378,6 +385,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat timers=timers) free_output_tensor(output_tensor, args.deallocate_pipeline_outputs) input_tensors[next_forward_model_chunk_id].append(input_tensor) assert_viewless_tensor(input_tensor) # Run 1F1B in steady state. for k in range(num_microbatches_remaining): Loading Loading @@ -447,15 +455,18 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat # right location. if recv_prev: input_tensors[next_forward_model_chunk_id].append(input_tensor) assert_viewless_tensor(input_tensor) if recv_next: output_tensor_grads[next_backward_model_chunk_id].append( output_tensor_grad) assert_viewless_tensor(output_tensor_grad) # Run cooldown backward passes (flush out pipeline). if not forward_only: if all_warmup_microbatches: output_tensor_grads[num_model_chunks-1].append( p2p_communication.recv_backward(tensor_shape, timers=timers)) assert_viewless_tensor(output_tensor_grads[num_model_chunks-1][-1]) for k in range(num_microbatches_remaining, num_microbatches): input_tensor_grad = backward_step_helper(k) next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False) Loading @@ -470,6 +481,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, timers=timers)) assert_viewless_tensor(output_tensor_grads[next_backward_model_chunk_id][-1]) return losses_reduced Loading Loading @@ -508,6 +520,7 @@ def recv_forward(tensor_shapes, timers): else: input_tensors.append(p2p_communication.recv_forward(tensor_shape, timers=timers)) assert_viewless_tensor(input_tensors[-1]) return input_tensors Loading @@ -519,6 +532,7 @@ def recv_backward(tensor_shapes, timers): else: output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, timers=timers)) assert_viewless_tensor(output_tensor_grads[-1]) return output_tensor_grads Loading Loading @@ -551,6 +565,7 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, timers): output_tensor_grad = p2p_communication.send_forward_recv_backward( output_tensor, tensor_shape, timers=timers) output_tensor_grads.append(output_tensor_grad) assert_viewless_tensor(output_tensor_grad) return output_tensor_grads Loading @@ -565,6 +580,7 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers): input_tensor = p2p_communication.send_backward_recv_forward( input_tensor_grad, tensor_shape, timers=timers) input_tensors.append(input_tensor) assert_viewless_tensor(input_tensor) return input_tensors Loading Loading @@ -615,6 +631,15 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite send_forward(output_tensor, send_tensor_shapes, timers=timers) if not forward_only: # >>> if input_tensor[0] is not None: from lutil import pax pax({ "input_tensor" : input_tensor, }) # <<< assert_viewless_tensor(input_tensor) assert_viewless_tensor(output_tensor) input_tensors.append(input_tensor) output_tensors.append(output_tensor) free_output_tensor(output_tensor, args.deallocate_pipeline_outputs) Loading Loading @@ -644,6 +669,17 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite timers=timers) # Add input_tensor and output_tensor to end of list. # >>> # assert input_tensor[0]._base is None, \ # "rank %s; uh oh." % torch.distributed.get_rank() # if input_tensor[0] is not None: # from lutil import pax # pax(4, { # "input_tensor[0]" : input_tensor[0], # }) # <<< assert_viewless_tensor(input_tensor) assert_viewless_tensor(output_tensor) input_tensors.append(input_tensor) output_tensors.append(output_tensor) free_output_tensor(output_tensor, args.deallocate_pipeline_outputs) Loading