Commit 8b9fe87b authored by Mohammad Shoeybi, committed by Jared Casper

Inference context optimization

parent f47aa770
+6 −15
@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal


 def post_language_model_processing(lm_output, labels, logit_weights,
-                                   get_key_value, parallel_output,
-                                   forward_method_parallel_output,
+                                   parallel_output,
                                    fp16_lm_cross_entropy):
-    if get_key_value:
-        lm_output, presents = lm_output
 
     # Output.
-    if forward_method_parallel_output is not None:
-        parallel_output = forward_method_parallel_output
     output = parallel_lm_logits(
         lm_output,
         logit_weights,
         parallel_output)
 
-    if get_key_value:
-        output = [output, presents]
-
     if labels is None:
         return output
     else:
@@ -90,23 +82,22 @@ class GPTModel(MegatronModule):
         self.language_model.set_input_tensor(input_tensor)
 
     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
+                tokentype_ids=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
 
         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
+            set_inference_key_value_memory=set_inference_key_value_memory,
+            inference_max_sequence_len=inference_max_sequence_len)
 
         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),
-                get_key_value,
                 self.parallel_output,
-                forward_method_parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
             return lm_output
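
Note (not part of the diff): with this change the GPT model no longer takes layer_past / get_key_value / forward_method_parallel_output. Instead, the first inference call pre-allocates key-value memory for the whole generation and later calls reuse it. A minimal sketch of the new calling pattern, assuming a built model and pre-computed tokens, position_ids and attention_mask, plus illustrative prompt_len and max_len values:

# Sketch only: two-phase inference with the new arguments.
# First call: run the whole prompt and allocate KV memory for up to max_len tokens.
output = model(tokens[:, :prompt_len],
               position_ids[:, :prompt_len],
               attention_mask,
               set_inference_key_value_memory=True,
               inference_max_sequence_len=max_len)

# Later calls: feed only the newly generated token; cached keys/values for
# earlier positions are reused instead of being recomputed.
output = model(tokens[:, prompt_len:prompt_len + 1],
               position_ids[:, prompt_len:prompt_len + 1],
               attention_mask,
               set_inference_key_value_memory=False,
               inference_max_sequence_len=max_len)
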
+16 −12
@@ -334,8 +334,10 @@ class TransformerLanguageModel(MegatronModule):

     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
-                get_key_value=False, pooling_sequence_index=0,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None,
+                pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
 
         # Embeddings.
@@ -348,10 +350,11 @@ class TransformerLanguageModel(MegatronModule):

         # encoder.
         if enc_hidden_states is None:
-            encoder_output = self.encoder(encoder_input,
-                                          enc_attn_mask,
-                                          layer_past=layer_past,
-                                          get_key_value=get_key_value)
+            encoder_output = self.encoder(
+                encoder_input,
+                enc_attn_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=inference_max_sequence_len)
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)

@@ -373,12 +376,13 @@ class TransformerLanguageModel(MegatronModule):
         dec_embedding_output = self.embedding(dec_input_ids,
                                               dec_position_ids)
         # decoder
-        decoder_output = self.decoder(dec_embedding_output,
-                                      dec_attn_mask,
-                                      layer_past=layer_past,
-                                      get_key_value=get_key_value,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask)
+        decoder_output = self.decoder(
+            dec_embedding_output,
+            dec_attn_mask,
+            encoder_output=encoder_output,
+            enc_dec_attn_mask=enc_dec_attn_mask,
+            set_inference_key_value_memory=set_inference_key_value_memory,
+            inference_max_sequence_len=inference_max_sequence_len)
 
         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
+87 −70
@@ -118,6 +118,7 @@ class ParallelAttention(MegatronModule):
         self.layer_number = max(1, layer_number)
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
+        self.params_dtype = args.params_dtype
 
         projection_size = args.kv_channels * args.num_attention_heads
 
@@ -178,10 +179,53 @@ class ParallelAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None):
+        # Inference key-value memory
+        self.inference_key_memory = None
+        self.inference_value_memory = None
+        self.inference_current_sequence_len = 0
+
+    def _allocate_memory(self, inference_max_sequence_len, batch_size):
+        return torch.empty(
+            inference_max_sequence_len,
+            batch_size,
+            self.num_attention_heads_per_partition,
+            self.hidden_size_per_attention_head,
+            dtype=self.params_dtype,
+            device=torch.cuda.current_device())
+
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
         # hidden_states: [sq, b, h]
 
+        # =================================================
+        # Pre-allocate memory for key-values for inference.
+        # =================================================
+        if set_inference_key_value_memory:
+            assert inference_max_sequence_len and inference_max_sequence_len > 0
+            self.inference_key_memory = self._allocate_memory(
+                inference_max_sequence_len, hidden_states.size(1))
+            self.inference_value_memory = self._allocate_memory(
+                inference_max_sequence_len, hidden_states.size(1))
+            self.inference_current_sequence_len = 0
+        # Some consistency check.
+        if inference_max_sequence_len:
+            assert self.inference_current_sequence_len < \
+                self.inference_key_memory.size(0)
+            assert inference_max_sequence_len == \
+                self.inference_key_memory.size(0)
+        # This is added for safety. In case inference_max_sequence_len
+        # is not provided, make sure there is no potential memory left
+        # from previous inference.
+        if not inference_max_sequence_len:
+            self.inference_key_memory = None
+            self.inference_value_memory = None
+
         # =====================
         # Query, Key, and Value
         # =====================
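
Aside (not part of the diff): each ParallelAttention instance now owns two buffers of shape [inference_max_sequence_len, batch, num_attention_heads_per_partition, hidden_size_per_attention_head], allocated once in params_dtype on the current GPU. A rough sketch of the per-layer cost, using illustrative sizes that are assumptions rather than values from this commit:

# Sketch only: approximate size of the inference key-value memory for one
# attention layer on one tensor-parallel rank. All numbers below are
# made-up examples, not values taken from the diff.
max_seq_len = 2048            # inference_max_sequence_len
batch_size = 8
heads_per_partition = 16      # num_attention_heads_per_partition
head_dim = 64                 # hidden_size_per_attention_head
bytes_per_element = 2         # params_dtype is fp16/bf16

one_buffer = max_seq_len * batch_size * heads_per_partition * head_dim * bytes_per_element
per_layer = 2 * one_buffer    # one buffer for keys, one for values
print(per_layer / 1024 ** 2, "MiB per layer")   # 64.0 MiB with these numbers
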
@@ -222,18 +266,24 @@ class ParallelAttention(MegatronModule):
                  self.hidden_size_per_attention_head)
             query_layer = query_layer.view(*new_tensor_shape)
 
-        # ==================================
-        # Adjust key and value for inference
-        # ==================================
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key_layer = torch.cat((past_key.type_as(key_layer),
-                                   key_layer), dim=0)
-            value_layer = torch.cat((past_value.type_as(value_layer),
-                                     value_layer), dim=0)
-        if get_key_value:
-            present = (key_layer, value_layer)
+        # ===================================================
+        # Adjust key, value, and attention mask for inference
+        # ===================================================
+
+        if inference_max_sequence_len:
+            # Adjust the range variables.
+            start = self.inference_current_sequence_len
+            self.inference_current_sequence_len += key_layer.size(0)
+            end = self.inference_current_sequence_len
+            # Copy key and values.
+            self.inference_key_memory[start:end, ...] = key_layer
+            self.inference_value_memory[start:end, ...] = value_layer
+            key_layer = self.inference_key_memory[:end, ...]
+            value_layer = self.inference_value_memory[:end, ...]
+            # Adjust attention mask
+            attention_mask = attention_mask[..., start:end, :end]
 
         # ===================================
         # Raw attention scores. [b, np, s, s]
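
Aside (not part of the diff): the block above replaces the old torch.cat of layer_past with in-place writes into the pre-allocated buffers, and slices the attention mask down to the positions actually present. A self-contained sketch of that indexing with small made-up sizes:

import torch

# Sketch only: append new keys/values into a pre-allocated cache and slice the
# mask accordingly, mirroring the logic added above.
max_len, batch, heads, dim = 16, 2, 4, 8
key_memory = torch.empty(max_len, batch, heads, dim)
value_memory = torch.empty(max_len, batch, heads, dim)
# True means "masked out"; lower-triangular positions stay visible.
full_mask = torch.ones(1, 1, max_len, max_len).triu(diagonal=1).bool()

def append_step(key_layer, value_layer, current_len):
    # key_layer / value_layer hold only the new tokens: [sq, batch, heads, dim].
    start = current_len
    end = start + key_layer.size(0)
    key_memory[start:end, ...] = key_layer
    value_memory[start:end, ...] = value_layer
    # Attend from the new positions to everything cached so far.
    mask = full_mask[..., start:end, :end]
    return key_memory[:end, ...], value_memory[:end, ...], mask, end

# First step: a 5-token prompt; later steps: one token at a time.
k, v, m, cur = append_step(torch.randn(5, batch, heads, dim),
                           torch.randn(5, batch, heads, dim), 0)
k, v, m, cur = append_step(torch.randn(1, batch, heads, dim),
                           torch.randn(1, batch, heads, dim), cur)
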
@@ -270,22 +320,6 @@ class ParallelAttention(MegatronModule):
         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)
 
-        # ==================================================
-        # Update attention mask for inference. [b, np, sq, sk]
-        # ==================================================
-
-        if get_key_value:
-            with torch.no_grad():
-                if layer_past is not None:
-                    attention_mask = attention_mask[
-                        ...,
-                        attention_scores.size(3) - 1,
-                        :attention_scores.size(3)].unsqueeze(2)
-                else:
-                    attention_mask = attention_mask[
-                        ...,
-                        :attention_scores.size(3),
-                        :attention_scores.size(3)]
 
         # ===========================
         # Attention probs and dropout
@@ -341,9 +375,6 @@ class ParallelAttention(MegatronModule):

         output, bias = self.dense(context_layer)
 
-        if get_key_value:
-            output = [output, present]
-
         return output, bias
 

@@ -430,21 +461,21 @@ class ParallelTransformerLayer(MegatronModule):
                                output_layer_init_method)
 
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, enc_dec_attn_mask=None,
-                layer_past=None, get_key_value=False):
+                encoder_output=None,
+                enc_dec_attn_mask=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
         # hidden_states: [b, s, h]
 
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
         attention_output, attention_bias = \
-            self.self_attention(layernorm_output,
-                                attention_mask,
-                                layer_past=layer_past,
-                                get_key_value=get_key_value)
-
-        if get_key_value:
-            attention_output, presents = attention_output
+            self.self_attention(
+                layernorm_output,
+                attention_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=inference_max_sequence_len)
 
         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
@@ -514,9 +545,6 @@ class ParallelTransformerLayer(MegatronModule):
                 residual,
                 self.hidden_dropout)
 
-        if get_key_value:
-            output = [output, presents]
-
         return output
 

@@ -659,18 +687,16 @@ class ParallelTransformer(MegatronModule):
         forward_step_func"""
         self.input_tensor = input_tensor
 
-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None,
+                enc_dec_attn_mask=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
 
         # Checks.
-        if layer_past is not None:
-            assert get_key_value, \
-                'for not None values in layer_past, ' \
-                'expected get_key_value to be set'
-        if get_key_value:
+        if inference_max_sequence_len:
             assert self.activations_checkpoint_method is None, \
-                'get_key_value does not work with ' \
-                'activation checkpointing'
+                'inference does not work with activation checkpointing'
 
         if self.pre_process:
             # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
@@ -693,22 +719,15 @@ class ParallelTransformer(MegatronModule):
                                                        encoder_output,
                                                        enc_dec_attn_mask)
         else:
-            if get_key_value:
-                presents = []
             for index in range(self.num_layers):
                 layer = self._get_layer(index)
-                past = None
-                if layer_past is not None:
-                    past = layer_past[index]
-                hidden_states = layer(hidden_states,
-                                      attention_mask,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask,
-                                      layer_past=past,
-                                      get_key_value=get_key_value)
-                if get_key_value:
-                    hidden_states, present = hidden_states
-                    presents.append(present)
+                hidden_states = layer(
+                    hidden_states,
+                    attention_mask,
+                    encoder_output=encoder_output,
+                    enc_dec_attn_mask=enc_dec_attn_mask,
+                    set_inference_key_value_memory=set_inference_key_value_memory,
+                    inference_max_sequence_len=inference_max_sequence_len)
 
         # Final layer norm.
         if self.post_process:
@@ -717,7 +736,5 @@ class ParallelTransformer(MegatronModule):
             output = self.final_layernorm(hidden_states)
         else:
             output = hidden_states
-        if get_key_value:
-            output = [output, presents]
 
         return output
+26 −21
@@ -227,8 +227,8 @@ def switch(val1, val2, boolean):


 def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
-                 layer_past=None, get_key_value=None,
-                 forward_method_parallel_output=None):
+                 set_inference_key_value_memory=False,
+                 inference_max_sequence_len=None):
 
     # Hidden size changes when not using recompute, need to tell p2p_communicate
     # functions the correct size
@@ -243,20 +243,16 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
     unwrapped_model = unwrap_model(
         model, (torchDDP, LocalDDP, Float16Module))
     unwrapped_model.set_input_tensor(input_tensor)
-    output_tensor = model(tokens, position_ids, attention_mask,
-                          tokentype_ids=tokentype_ids,
-                          layer_past=layer_past,
-                          get_key_value=get_key_value,
-                          forward_method_parallel_output=forward_method_parallel_output)
-
-    if get_key_value:
-        output_tensor, layer_past = output_tensor
+    output_tensor = model(
+        tokens, position_ids, attention_mask,
+        tokentype_ids=tokentype_ids,
+        set_inference_key_value_memory=set_inference_key_value_memory,
+        inference_max_sequence_len=inference_max_sequence_len)
 
     send_forward(output_tensor)
 
     args.seq_length = orig_seq_length
-    if get_key_value:
-        return output_tensor, layer_past
 
     return output_tensor
 

@@ -279,7 +275,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,

         counter = 0
 
-        layer_past = None
         batch_size = context_tokens.size(0)
         is_done = torch.zeros([batch_size]).byte().cuda()
         tokens = context_tokens
@@ -296,11 +291,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         while context_length < maxlen:
             types2use = None
             if counter == 0:
+                # Allocate memory for the entire context.
+                set_inference_key_value_memory = True
                 tokens2use = tokens[:, :context_length]
                 positions2use = position_ids[:, :context_length]
                 if type_ids is not None:
                     types2use = type_ids[:, :context_length]
             else:
+                # Set this to false so the memory is not reallocated.
+                set_inference_key_value_memory = False
                 tokens2use = tokens[:, context_length - 1].view(
                     batch_size, -1)
                 positions2use = position_ids[:, context_length - 1].view(
@@ -308,18 +307,20 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 if type_ids is not None:
                     types2use = type_ids[:, context_length - 1].view(
                         batch_size, -1)
-            output, layer_past = forward_step(model, tokens2use,
-                                              positions2use,
-                                              attention_mask,
-                                              layer_past=layer_past,
-                                              get_key_value=True,
-                                              tokentype_ids=types2use,
-                                              forward_method_parallel_output=False)
+
+            output = forward_step(
+                model, tokens2use,
+                positions2use,
+                attention_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=maxlen,
+                tokentype_ids=types2use)
 
             if mpu.is_pipeline_last_stage():
                 assert output is not None
+                output = output.float()
                 logits = output[:, -1].view(batch_size, -1).contiguous()
 
             if mpu.is_pipeline_last_stage():
                 if args.greedy:
                     prev = torch.argmax(logits, dim=-1).view(-1)
                 else:
@@ -331,6 +332,10 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                 started = context_lengths <= context_length
 
+                # Clamp the out of vocabulary tokens.
+                tokenizer = get_tokenizer()
+                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
+
                 new_tokens = switch(
                     tokens[:, context_length].view(-1), prev, started)
                 tokens[:, context_length] = new_tokens
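
Aside (not part of the diff): the sampling loop now drives the cache through the two new flags instead of threading layer_past between calls. A condensed sketch of that control flow, with sampling and pipeline-communication details omitted:

# Sketch only: how the generation loop toggles the inference flags.
context_length = context_lengths.min().item()
counter = 0
while context_length < maxlen:
    if counter == 0:
        # First pass: run the full prompt and allocate KV memory for maxlen tokens.
        set_inference_key_value_memory = True
        tokens2use = tokens[:, :context_length]
        positions2use = position_ids[:, :context_length]
    else:
        # Later passes: only the latest token; the memory is reused, not reallocated.
        set_inference_key_value_memory = False
        tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
        positions2use = position_ids[:, context_length - 1].view(batch_size, -1)

    output = forward_step(model, tokens2use, positions2use, attention_mask,
                          tokentype_ids=None,
                          set_inference_key_value_memory=set_inference_key_value_memory,
                          inference_max_sequence_len=maxlen)
    # ... sample the next token from output and write it into tokens ...
    context_length += 1
    counter += 1
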
+18 −16
@@ -189,7 +189,7 @@ def update_train_iters(args):
     print_rank_0('setting training iterations to {}'.format(args.train_iters))
 
 
-def get_model(model_provider_func):
+def get_model(model_provider_func, wrap_with_ddp=True):
     """Build the model."""
     args = get_args()
 
@@ -243,22 +243,24 @@ def get_model(model_provider_func):
     if args.fp16 or args.bf16:
         model = [Float16Module(model_module, args) for model_module in model]
 
-    if args.DDP_impl == 'torch':
-        i = torch.cuda.current_device()
-        model = [torchDDP(model_module, device_ids=[i], output_device=i,
-                          process_group=mpu.get_data_parallel_group())
-                 for model_module in model]
-        return model
+    if wrap_with_ddp:
+        if args.DDP_impl == 'torch':
+            i = torch.cuda.current_device()
+            model = [torchDDP(model_module, device_ids=[i], output_device=i,
+                              process_group=mpu.get_data_parallel_group())
+                     for model_module in model]
 
-    if args.DDP_impl == 'local':
-        model = [LocalDDP(model_module,
-                          args.accumulate_allreduce_grads_in_fp32,
-                          args.use_contiguous_buffers_in_local_ddp)
-                 for model_module in model]
-        return model
+        elif args.DDP_impl == 'local':
+            model = [LocalDDP(model_module,
+                              args.accumulate_allreduce_grads_in_fp32,
+                              args.use_contiguous_buffers_in_local_ddp)
+                     for model_module in model]
 
-    raise NotImplementedError('Unknown DDP implementation specified: {}. '
-                              'Exiting.'.format(args.DDP_impl))
+        else:
+            raise NotImplementedError('Unknown DDP implementation specified: '
+                                      '{}. Exiting.'.format(args.DDP_impl))
+
+    return model
 

def get_learning_rate_scheduler(optimizer):
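
Aside (not part of the diff): the new wrap_with_ddp flag lets inference code build the model without any DDP wrapper, while training keeps the default wrapped behavior. A minimal sketch, assuming a model_provider function like the ones used by the training scripts:

# Sketch only: building an unwrapped model for text generation.
from megatron.training import get_model

def model_provider(pre_process=True, post_process=True):
    ...  # build and return the model, as in the existing training scripts

model = get_model(model_provider, wrap_with_ddp=False)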