Loading megatron/text_generation_utils.py +1 −1 Original line number Diff line number Diff line Loading @@ -121,7 +121,7 @@ def receive_generate_info(): """ Needs to be synced up with send_generate_info """ input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.cuda.current_device()) input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.cuda.current_device()) torch.distributed.broadcast(input_info_tensor, 0) batch_size = input_info_tensor[0].item() seq_len = input_info_tensor[1].item() Loading Loading
megatron/text_generation_utils.py +1 −1 Original line number Diff line number Diff line Loading @@ -121,7 +121,7 @@ def receive_generate_info(): """ Needs to be synced up with send_generate_info """ input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.cuda.current_device()) input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.cuda.current_device()) torch.distributed.broadcast(input_info_tensor, 0) batch_size = input_info_tensor[0].item() seq_len = input_info_tensor[1].item() Loading