megatron/text_generation_utils.py  (+11 −10)

@@ -108,13 +108,13 @@ def tokenize_batch(sentences, max_len, add_BOS):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor
 
-def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
+def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs]
-    input_info_tensor = torch.cuda.LongTensor(input_info)
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs, temperature]
+    input_info_tensor = torch.cuda.FloatTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)
 
     # Send variables to all ranks
@@ -125,12 +125,13 @@ def receive_generate_info():
     """
     Needs to be synced up with send_generate_info
     """
-    input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.cuda.current_device())
+    input_info_tensor = torch.empty(5, dtype=torch.float32, device=torch.cuda.current_device())
     torch.distributed.broadcast(input_info_tensor, 0)
-    batch_size = input_info_tensor[0].item()
-    seq_len = input_info_tensor[1].item()
-    tokens_to_generate = input_info_tensor[2].item()
-    all_probs = input_info_tensor[3].item()
+    batch_size = int(input_info_tensor[0].item())
+    seq_len = int(input_info_tensor[1].item())
+    tokens_to_generate = int(input_info_tensor[2].item())
+    all_probs = int(input_info_tensor[3].item())
+    temperature = float(input_info_tensor[4].item())
 
     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
@@ -139,7 +140,7 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)
 
-    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs
+    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature
 
 def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     context_length = context_length_tensor.min().item()
@@ -182,7 +183,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
-        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
+        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
     else:
         context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
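The core of the change is the metadata broadcast: because `temperature` is a float, the packed scalar tensor switches from `torch.cuda.LongTensor` to `torch.cuda.FloatTensor`, and the receiving ranks now cast the integer fields back with `int()`. Below is a minimal, self-contained sketch of that pack/broadcast/unpack pattern. It is not code from this PR: the function name `broadcast_generate_args` and its arguments are illustrative, and it assumes `torch.distributed` is already initialized with one CUDA device per rank.

```python
import torch
import torch.distributed as dist

def broadcast_generate_args(batch_size=None, seq_len=None, tokens_to_generate=None,
                            all_probs=None, temperature=None, src=0):
    """Broadcast mixed int/float generation settings from `src` to all ranks."""
    device = torch.cuda.current_device()
    if dist.get_rank() == src:
        # Pack everything into one float32 tensor; float32 represents these
        # small integers exactly, so one broadcast carries all the scalars.
        info = torch.tensor(
            [batch_size, seq_len, tokens_to_generate, float(all_probs), temperature],
            dtype=torch.float32, device=device)
    else:
        info = torch.empty(5, dtype=torch.float32, device=device)
    dist.broadcast(info, src)
    # .item() on a float tensor returns a Python float, so cast the integer
    # and boolean fields back to their original types on every rank.
    return (int(info[0].item()), int(info[1].item()), int(info[2].item()),
            bool(int(info[3].item())), float(info[4].item()))
```

The trade-off of using a single float32 tensor is that integers are only represented exactly up to 2**24, which is comfortably above the batch size, sequence length, and token counts being broadcast here.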