Commit 1d4e8760 authored by Jared Casper, committed by Deepak Narayanan

Fix text generation without recompute

parent 2623551d
+22 −15
@@ -138,23 +138,23 @@ def generate_samples_input_from_file(model):
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

-            # For pipeline parallel we send context tokens to last stage
-            # so it knows when to start overwriting
+            # For pipeline parallel we send context tokens to other stages
+            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_embedding_group()
+                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
-                if mpu.is_pipeline_last_stage():
+                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_embedding_group()
-                    context_length = input_info_tensor[2].item()
+                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
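
Note: this hunk (and the matching one in generate_samples_interactive below) changes how the prompt reaches the pipeline stages. Previously only the last stage received the context tokens, over the embedding group; now rank 0 of the first stage broadcasts them to every other stage over the pipeline-parallel group, so all stages agree on the context length. A minimal sketch of the resulting pattern, assuming the mpu helpers already used in this file and a hypothetical wrapper name:

    import torch
    import torch.distributed
    from megatron import mpu

    def broadcast_context_tokens(context_tokens, context_length):
        # The rank that owns the tokenized prompt is the first pipeline stage.
        src = mpu.get_pipeline_model_parallel_first_rank()
        group = mpu.get_pipeline_model_parallel_group()
        if mpu.is_pipeline_first_stage():
            # Sender: pack the prompt tokens and broadcast them.
            tokens = torch.cuda.LongTensor(context_tokens)
            torch.distributed.broadcast(tokens, src, group)
            return context_tokens
        # Receivers: allocate a buffer of the agreed-upon length and fill it.
        tokens = torch.empty(context_length, dtype=torch.int64,
                             device=torch.device("cuda"))
        torch.distributed.broadcast(tokens, src, group)
        return tokens.cpu().numpy().tolist()
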
@@ -229,23 +229,23 @@ def generate_samples_interactive(model, print_frequency=24):
                                         group=mpu.get_model_parallel_group())
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

-            # For pipeline parallel we send context tokens to last stage
-            # so it knows when to start overwriting
+            # For pipeline parallel we send context tokens to other stages
+            # so they get the lengths correct
            if mpu.get_tensor_model_parallel_rank() == 0 \
               and args.pipeline_model_parallel_size > 1:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_embedding_group()
+                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                    torch.distributed.broadcast(context_tokens_tensor, src, group)
-                if mpu.is_pipeline_last_stage():
+                else:
                    src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_embedding_group()
-                    context_length = input_info_tensor[2].item()
+                    group = mpu.get_pipeline_model_parallel_group()
                    context_tokens_tensor = torch.empty(context_length,
                                                        dtype=torch.int64,
                                                        device=torch.device("cuda"))
@@ -253,6 +253,7 @@ def generate_samples_interactive(model, print_frequency=24):
                    context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
+
            for counter, decode_tokens in enumerate(token_stream):
                if counter % print_frequency != 0 \
                   or mpu.get_tensor_model_parallel_rank() != 0 \
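
Note: the loop this hunk leads into streams partial generations back to the terminal. Its body is outside the diff; a rough sketch of the pattern, in which the tokenizer access, the tuple layout of decode_tokens, and the prompt trimming are assumptions rather than part of this commit:

    from megatron import get_tokenizer, mpu

    def stream_and_print(token_stream, raw_text_len, print_frequency=24):
        # Assumed convention: only tensor-parallel rank 0 of the first
        # pipeline stage prints, and only every `print_frequency` steps.
        tokenizer = get_tokenizer()
        for counter, decode_tokens in enumerate(token_stream):
            if counter % print_frequency != 0 \
               or mpu.get_tensor_model_parallel_rank() != 0 \
               or not mpu.is_pipeline_first_stage():
                continue
            tokens = decode_tokens[0][0].cpu().numpy().tolist()
            # Strip the prompt so only the newly generated text is shown.
            print(tokenizer.detokenize(tokens)[raw_text_len:], flush=True)
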
@@ -394,6 +395,12 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):

+    # Hidden size changes when not using recompute, need to tell communicate()
+    # the correct size
+    args = get_args()
+    orig_seq_length = args.seq_length
+    args.seq_length = tokens.shape[1]
+
    if not mpu.is_pipeline_first_stage():
        input_tensor, _ = communicate(
            tensor_send_next=None,
@@ -437,8 +444,8 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)
        return None

+    args.seq_length = orig_seq_length
    if get_key_value:
        return output_tensor, layer_past
    return output_tensor
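
Note: the two hunks above bracket the pipeline communication in forward_step with a temporary override of args.seq_length. Without recompute, each decode step feeds only the tokens that are not yet in the layer_past cache, so the activations exchanged by communicate() have a smaller sequence dimension than the configured seq_length; overriding it with tokens.shape[1] lets communicate() size its buffers correctly, and the original value is restored before the final return. A hypothetical wrapper, not part of this commit, that captures the same save/override/restore:

    from contextlib import contextmanager
    from megatron import get_args

    @contextmanager
    def actual_seq_length(tokens):
        # Temporarily report the true sequence length of this step's input.
        args = get_args()
        orig_seq_length = args.seq_length
        args.seq_length = tokens.shape[1]
        try:
            yield
        finally:
            # Restore on every exit path, including early returns.
            args.seq_length = orig_seq_length

A context manager restores the value on every exit path; the inline version in forward_step restores it just before returning the output tensor.
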
@@ -495,7 +502,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
-                logits, layer_past = forward_step(model, tokens2use,
+                output, layer_past = forward_step(model, tokens2use,
                                                  positions2use,
                                                  attention_mask,
                                                  layer_past=layer_past,
@@ -504,7 +511,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                                  forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
-                    logits = logits[:, -1].view(batch_size, -1).contiguous()
+                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if args.greedy:
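
Note: the last two hunks rename the result of the non-recompute forward_step call from logits to output, so the last pipeline stage slices the next-token logits from the tensor that forward_step actually returned (intermediate stages return None). A minimal sketch of the sampling step on the last stage, assuming output has shape [batch, seq, vocab] and greedy decoding:

    import torch

    def greedy_next_token(output, batch_size):
        # Logits for the most recent position of every sample in the batch.
        logits = output[:, -1].view(batch_size, -1).contiguous()
        # Greedy decoding: pick the highest-scoring token id.
        return torch.argmax(logits, dim=-1)
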