Commit 0b0e37f0 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'inference' into 'main'

Inference refactoring

See merge request ADLR/megatron-lm!339
parents 5478d67e 71359e1f
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -427,7 +427,7 @@ Several downstream tasks are described for both GPT and BERT models below. They

## GPT Text Generation

We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also few optional parameters: `temperature`, `top-k`, `top-p`, and `greedy`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.
We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also a few optional parameters: `temperature`, `top-k` and `top-p`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.

Once the server is running you can use `tools/text_generation_cli.py` to query it, it takes one argument which is the host the server is running on.

+13 −0
Original line number Diff line number Diff line
@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
    parser = _add_biencoder_args(parser)
    parser = _add_vit_args(parser)
    parser = _add_logging_args(parser)
    parser = _add_inference_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
@@ -279,6 +280,18 @@ def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


def _add_inference_args(parser):
    group = parser.add_argument_group(title='inference')

    group.add_argument('--inference-batch-times-seqlen-threshold',
                       type=int, default=512,
                       help='During inference, if batch-size times '
                       'sequence-length is smaller than this threshold '
                       'then we will not use pipelining, otherwise we will.')

    return parser

    
def _add_network_size_args(parser):
    group = parser.add_argument_group(title='network size')

+2 −5
Original line number Diff line number Diff line
@@ -82,16 +82,13 @@ class GPTModel(MegatronModule):
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                tokentype_ids=None,
                set_inference_key_value_memory=False,
                inference_max_sequence_len=None):
                tokentype_ids=None, inference_params=None):

        lm_output = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            set_inference_key_value_memory=set_inference_key_value_memory,
            inference_max_sequence_len=inference_max_sequence_len)
            inference_params=inference_params)

        if self.post_process:
            return post_language_model_processing(
+3 −6
Original line number Diff line number Diff line
@@ -386,8 +386,7 @@ class TransformerLanguageModel(MegatronModule):
    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None,
                set_inference_key_value_memory=False,
                inference_max_sequence_len=None,
                inference_params=None,
                pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):

@@ -404,8 +403,7 @@ class TransformerLanguageModel(MegatronModule):
                encoder_output = self.encoder(
                    encoder_input,
                    enc_attn_mask,
                    set_inference_key_value_memory=set_inference_key_value_memory,
                    inference_max_sequence_len=inference_max_sequence_len)
                    inference_params=inference_params)
            else:
                encoder_output = self.encoder_hidden_state
        else:
@@ -438,8 +436,7 @@ class TransformerLanguageModel(MegatronModule):
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            set_inference_key_value_memory=set_inference_key_value_memory,
            inference_max_sequence_len=inference_max_sequence_len)
            inference_params=inference_params)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
+4 −0
Original line number Diff line number Diff line
@@ -187,6 +187,10 @@ class Float16Module(MegatronModule):
        self.float16_convertor = float16_convertor


    def set_input_tensor(self, input_tensor):
        """Pass *input_tensor* through to the wrapped module.

        Delegates so that pipeline-parallel scheduling code can treat
        this float16 wrapper exactly like the underlying model.
        """
        wrapped_module = self.module
        return wrapped_module.set_input_tensor(input_tensor)


    def forward(self, *inputs, **kwargs):
        if mpu.is_pipeline_first_stage():
            inputs = fp32_to_float16(inputs, self.float16_convertor)
Loading