megatron/inference/forward_step.py  +2 −19

@@ -16,25 +16,10 @@
 """Forward step utilities."""

 import torch

 from megatron.p2p_communication import recv_forward, send_forward
-from .sampling import sample
-from megatron import mpu
-import torch.nn.functional as F
-from megatron import print_rank_0
-from megatron import get_args, get_tokenizer
-from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
-from .communication import (
-    broadcast_float_list,
-    copy_from_last_to_first_pipeline_stage,
-    broadcast_from_last_pipeline_stage)
-from .tokenization import tokenize_prompts
-
-# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model import Float16Module
+from megatron import get_args


 def forward_step(model, tokens, position_ids, attention_mask,
@@ -51,9 +36,7 @@ def forward_step(model, tokens, position_ids, attention_mask,
     input_tensor = recv_forward()

     # Forward pass through the model.
-    unwrapped_model = unwrap_model(
-        model, (torchDDP, LocalDDP, Float16Module))
-    unwrapped_model.set_input_tensor(input_tensor)
+    model.set_input_tensor(input_tensor)
    output_tensor = model(
        tokens, position_ids, attention_mask,
        set_inference_key_value_memory=set_inference_key_value_memory,
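Once Float16Module forwards set_input_tensor itself (see the second file below), the call site can invoke it on the wrapped model directly, so the unwrap_model step and the torchDDP/LocalDDP/Float16Module imports it required become dead code; the remaining deleted imports appear to be otherwise unused in this file.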
megatron/model/module.py  +4 −0

@@ -166,6 +166,10 @@ class Float16Module(MegatronModule):

         self.float16_convertor = float16_convertor

+    def set_input_tensor(self, input_tensor):
+        return self.module.set_input_tensor(input_tensor)
+
     def forward(self, *inputs, **kwargs):
         if mpu.is_pipeline_first_stage():
             inputs = fp32_to_float16(inputs, self.float16_convertor)
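The pattern here is plain delegation: the wrapper exposes the same set_input_tensor hook as the module it wraps and simply forwards the call. Below is a minimal, self-contained sketch of that pattern; ToyStage and Wrapper are hypothetical stand-ins (dtype conversion and pipeline communication omitted), not Megatron's actual classes.

import torch


class ToyStage(torch.nn.Module):
    """Stand-in for a pipeline-stage model exposing set_input_tensor()."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # Stash the activation received from the previous pipeline stage.
        self.input_tensor = input_tensor

    def forward(self):
        return self.linear(self.input_tensor)


class Wrapper(torch.nn.Module):
    """Plays the role of Float16Module (dtype conversion omitted)."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def set_input_tensor(self, input_tensor):
        # The change from megatron/model/module.py: delegate the hook so
        # callers never need to unwrap the model first.
        return self.module.set_input_tensor(input_tensor)

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)


model = Wrapper(ToyStage())
model.set_input_tensor(torch.randn(2, 4))  # no unwrap_model() required
print(model().shape)  # torch.Size([2, 4])

Keeping the wrapper interface-transparent this way lets code that may receive either a bare or a wrapped model treat both uniformly, and the old inline comment in forward_step.py ("would be nice to put these in megatron.utils if possible?") disappears along with the problem it described.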