megatron/text_generation_utils.py  +22 −15

@@ -138,23 +138,23 @@ def generate_samples_input_from_file(model):
                                     group=mpu.get_model_parallel_group())
         terminate_runs = input_info_tensor[0].item()
         raw_text_len = input_info_tensor[1].item()
+        context_length = input_info_tensor[2].item()
 
         if terminate_runs == 1:
             return
 
-        # For pipeline parallel we send context tokens to last stage
-        # so it knows when to start overwriting
+        # For pipeline parallel we send context tokens to other stages
+        # so they get the lengths correct
         if mpu.get_tensor_model_parallel_rank() == 0 \
            and args.pipeline_model_parallel_size > 1:
             if mpu.is_pipeline_first_stage():
                 src = mpu.get_pipeline_model_parallel_first_rank()
-                group = mpu.get_embedding_group()
+                group = mpu.get_pipeline_model_parallel_group()
                 context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                 torch.distributed.broadcast(context_tokens_tensor, src, group)
-            if mpu.is_pipeline_last_stage():
+            else:
                 src = mpu.get_pipeline_model_parallel_first_rank()
-                group = mpu.get_embedding_group()
-                context_length = input_info_tensor[2].item()
+                group = mpu.get_pipeline_model_parallel_group()
                 context_tokens_tensor = torch.empty(context_length,
                                                     dtype=torch.int64,
                                                     device=torch.device("cuda"))
@@ -229,23 +229,23 @@ def generate_samples_interactive(model, print_frequency=24):
                                     group=mpu.get_model_parallel_group())
         terminate_runs = input_info_tensor[0].item()
         raw_text_len = input_info_tensor[1].item()
+        context_length = input_info_tensor[2].item()
 
         if terminate_runs == 1:
             return
 
-        # For pipeline parallel we send context tokens to last stage
-        # so it knows when to start overwriting
+        # For pipeline parallel we send context tokens to other stages
+        # so they get the lengths correct
         if mpu.get_tensor_model_parallel_rank() == 0 \
            and args.pipeline_model_parallel_size > 1:
             if mpu.is_pipeline_first_stage():
                 src = mpu.get_pipeline_model_parallel_first_rank()
-                group = mpu.get_embedding_group()
+                group = mpu.get_pipeline_model_parallel_group()
                 context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                 torch.distributed.broadcast(context_tokens_tensor, src, group)
-            if mpu.is_pipeline_last_stage():
+            else:
                 src = mpu.get_pipeline_model_parallel_first_rank()
-                group = mpu.get_embedding_group()
-                context_length = input_info_tensor[2].item()
+                group = mpu.get_pipeline_model_parallel_group()
                 context_tokens_tensor = torch.empty(context_length,
                                                     dtype=torch.int64,
                                                     device=torch.device("cuda"))
@@ -253,6 +253,7 @@ def generate_samples_interactive(model, print_frequency=24):
             context_tokens = context_tokens_tensor.cpu().numpy().tolist()
 
         token_stream = get_token_stream(model, [context_tokens])
+
         for counter, decode_tokens in enumerate(token_stream):
             if counter % print_frequency != 0 \
                or mpu.get_tensor_model_parallel_rank() != 0 \
@@ -394,6 +395,12 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                  layer_past=None, get_key_value=None,
                  forward_method_parallel_output=None):
 
+    # Hidden size changes when not using recompute, need to tell communicate()
+    # the correct size
+    args = get_args()
+    orig_seq_length = args.seq_length
+    args.seq_length = tokens.shape[1]
+
     if not mpu.is_pipeline_first_stage():
         input_tensor, _ = communicate(
             tensor_send_next=None,
@@ -437,8 +444,8 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
             tensor_send_prev=None,
             recv_forward=False,
             recv_backward=False)
-        return None
 
+    args.seq_length = orig_seq_length
     if get_key_value:
         return output_tensor, layer_past
     return output_tensor
@@ -495,7 +502,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 if type_ids is not None:
                     types2use = type_ids[:, context_length - 1].view(
                         batch_size, -1)
-            logits, layer_past = forward_step(model, tokens2use,
+            output, layer_past = forward_step(model, tokens2use,
                                               positions2use,
                                               attention_mask,
                                               layer_past=layer_past,
@@ -504,7 +511,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                               forward_method_parallel_output=False)
             if mpu.is_pipeline_last_stage():
-                logits = logits[:, -1].view(batch_size, -1).contiguous()
+                assert output is not None
+                logits = output[:, -1].view(batch_size, -1).contiguous()
 
         if mpu.is_pipeline_last_stage():
             if args.greedy:
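The first two hunks apply the same fix in two places: context tokens are now broadcast over the full pipeline-model-parallel group instead of the embedding group (which contains only the first and last stages), and context_length is read on every rank up front so non-first stages can size their receive buffers before the broadcast arrives. Below is a minimal standalone sketch of that receive-buffer pattern, assuming a plain torchrun launch with NCCL; FIRST_RANK and broadcast_context_tokens are illustrative stand-ins, not Megatron's mpu API.

# Sketch: run with `torchrun --nproc_per_node=2 broadcast_sketch.py`.
# FIRST_RANK and the helper name are illustrative, not Megatron's mpu API.
import os
import torch
import torch.distributed as dist

FIRST_RANK = 0  # stand-in for mpu.get_pipeline_model_parallel_first_rank()

def broadcast_context_tokens(context_tokens, context_length, group=None):
    """First rank broadcasts the tokens; every other rank allocates a
    receive buffer of the already-known length, as in the patch."""
    if dist.get_rank() == FIRST_RANK:
        tokens = torch.tensor(context_tokens, dtype=torch.int64, device="cuda")
    else:
        tokens = torch.empty(context_length, dtype=torch.int64,
                             device=torch.device("cuda"))
    dist.broadcast(tokens, FIRST_RANK, group=group)
    return tokens.cpu().tolist()

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    prompt = [101, 2023, 2003, 1037, 7953, 102]
    # In the patch, the length travels ahead of the tokens inside
    # input_info_tensor, so every rank already knows it at this point.
    received = broadcast_context_tokens(
        prompt if dist.get_rank() == FIRST_RANK else None,
        context_length=len(prompt))
    print(f"rank {dist.get_rank()} got {received}")
    dist.destroy_process_group()

This also explains the comment change: the old embedding-group broadcast only reached the last stage ("so it knows when to start overwriting"), while intermediate stages are not members of that group and would never receive the tokens they now need.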
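In forward_step, the new save/override/restore of args.seq_length exists because communicate() derives its receive-buffer shape from args.seq_length, and during incremental decoding the per-step length is tokens.shape[1], not the full sequence length. The patch restores the value inline; a context manager would be one way to make that restore exception-safe. This is a hypothetical refactor for illustration, not part of the patch:

from contextlib import contextmanager

@contextmanager
def override_seq_length(args, step_length):
    """Temporarily point args.seq_length at the incremental-decoding step
    length so shape-dependent code (here, communicate()) sizes its buffers
    correctly, then restore the original value."""
    orig_seq_length = args.seq_length
    args.seq_length = step_length
    try:
        yield
    finally:
        args.seq_length = orig_seq_length  # restored even if the step raises

# Usage inside forward_step would look roughly like:
#     with override_seq_length(get_args(), tokens.shape[1]):
#         ...  # communicate() + the model forward pass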