Commit cb57c380 authored by rprenger

Fixing merge conflicts

parents 7bdeb1e7 87023abd
+6 −15
@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal
 
 
 def post_language_model_processing(lm_output, labels, logit_weights,
-                                   get_key_value, parallel_output,
-                                   forward_method_parallel_output,
+                                   parallel_output,
                                    fp16_lm_cross_entropy):
-    if get_key_value:
-        lm_output, presents = lm_output
 
     # Output.
-    if forward_method_parallel_output is not None:
-        parallel_output = forward_method_parallel_output
     output = parallel_lm_logits(
         lm_output,
         logit_weights,
         parallel_output)
 
-    if get_key_value:
-        output = [output, presents]
-
     if labels is None:
         return output
     else:
@@ -90,23 +82,22 @@ class GPTModel(MegatronModule):
         self.language_model.set_input_tensor(input_tensor)
 
     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
+                tokentype_ids=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
 
         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
+            set_inference_key_value_memory=set_inference_key_value_memory,
+            inference_max_sequence_len=inference_max_sequence_len)
 
         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),
-                get_key_value,
                 self.parallel_output,
-                forward_method_parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
             return lm_output
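
The reworked GPTModel.forward drops the per-step layer_past / get_key_value plumbing in favor of two flags that drive a pre-allocated key/value cache inside the attention layers. A minimal calling sketch, assuming an already-constructed GPTModel `model` and hypothetical prompt tensors (these variable names are illustrative, not from the repo):

    # Prompt phase: run every prompt token at once and pre-allocate the
    # cache for the longest sequence this generation is allowed to reach.
    logits = model(prompt_tokens, prompt_position_ids, attention_mask,
                   set_inference_key_value_memory=True,
                   inference_max_sequence_len=max_generation_len)

    # Decode phase: feed only the newest token each step; the cache that
    # was allocated above is appended to internally.
    logits = model(last_token, last_position_id, attention_mask,
                   set_inference_key_value_memory=False,
                   inference_max_sequence_len=max_generation_len)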
+16 −12
@@ -334,8 +334,10 @@ class TransformerLanguageModel(MegatronModule):
 
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
-                get_key_value=False, pooling_sequence_index=0,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None,
+                pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
 
         # Embeddings.
@@ -348,10 +350,11 @@
 
         # encoder.
         if enc_hidden_states is None:
-            encoder_output = self.encoder(encoder_input,
-                                          enc_attn_mask,
-                                          layer_past=layer_past,
-                                          get_key_value=get_key_value)
+            encoder_output = self.encoder(
+                encoder_input,
+                enc_attn_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=inference_max_sequence_len)
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)
 
@@ -373,12 +376,13 @@
         dec_embedding_output = self.embedding(dec_input_ids,
                                               dec_position_ids)
         # decoder
-        decoder_output = self.decoder(dec_embedding_output,
-                                      dec_attn_mask,
-                                      layer_past=layer_past,
-                                      get_key_value=get_key_value,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask)
+        decoder_output = self.decoder(
+            dec_embedding_output,
+            dec_attn_mask,
+            encoder_output=encoder_output,
+            enc_dec_attn_mask=enc_dec_attn_mask,
+            set_inference_key_value_memory=set_inference_key_value_memory,
+            inference_max_sequence_len=inference_max_sequence_len)
 
         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
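
TransformerLanguageModel only threads the two new keyword arguments through to its encoder and decoder stacks; the cache is no longer handed back to the caller as `presents` but kept inside each attention layer. A toy driver loop, under that assumption (`lm`, `tok`, `pos`, and `mask` are illustrative placeholders, not repo code), changes shape roughly like this:

    # Before: the caller owned the cache and threaded it back in every step.
    past = None
    for step in range(num_steps):
        out, past = lm(tok, pos, mask, layer_past=past, get_key_value=True)

    # After: the layers own the cache; the caller only signals when to
    # (re)allocate it and how long the sequence may grow.
    for step in range(num_steps):
        out = lm(tok, pos, mask,
                 set_inference_key_value_memory=(step == 0),
                 inference_max_sequence_len=num_steps)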
+87 −70
@@ -118,6 +118,7 @@ class ParallelAttention(MegatronModule):
         self.layer_number = max(1, layer_number)
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
+        self.params_dtype = args.params_dtype
 
         projection_size = args.kv_channels * args.num_attention_heads
 
@@ -178,10 +179,53 @@
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None):
+        # Inference key-value memory
+        self.inference_key_memory = None
+        self.inference_value_memory = None
+        self.inference_current_sequence_len = 0
+
+
+    def _allocate_memory(self, inference_max_sequence_len, batch_size):
+        return torch.empty(
+            inference_max_sequence_len,
+            batch_size,
+            self.num_attention_heads_per_partition,
+            self.hidden_size_per_attention_head,
+            dtype=self.params_dtype,
+            device=torch.cuda.current_device())
+
+
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
         # hidden_states: [sq, b, h]
 
+
+        # =================================================
+        # Pre-allocate memory for key-values for inference.
+        # =================================================
+        if set_inference_key_value_memory:
+            assert inference_max_sequence_len and inference_max_sequence_len > 0
+            self.inference_key_memory = self._allocate_memory(
+                inference_max_sequence_len, hidden_states.size(1))
+            self.inference_value_memory = self._allocate_memory(
+                inference_max_sequence_len, hidden_states.size(1))
+            self.inference_current_sequence_len = 0
+        # Some consistency check.
+        if inference_max_sequence_len:
+            assert self.inference_current_sequence_len < \
+                self.inference_key_memory.size(0)
+            assert inference_max_sequence_len == \
+                self.inference_key_memory.size(0)
+        # This is added for safety. In case inference_max_sequence_len
+        # is not provided, make sure there is no potential memory left
+        # from previous inference.
+        if not inference_max_sequence_len:
+            self.inference_key_memory = None
+            self.inference_value_memory = None
+
+
         # =====================
         # Query, Key, and Value
         # =====================
@@ -222,18 +266,24 @@
                  self.hidden_size_per_attention_head)
             query_layer = query_layer.view(*new_tensor_shape)
 
-        # ==================================
-        # Adjust key and value for inference
-        # ==================================
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key_layer = torch.cat((past_key.type_as(key_layer),
-                                   key_layer), dim=0)
-            value_layer = torch.cat((past_value.type_as(value_layer),
-                                     value_layer), dim=0)
-        if get_key_value:
-            present = (key_layer, value_layer)
+        # ===================================================
+        # Adjust key, value, and attention mask for inference
+        # ===================================================
+
+        if inference_max_sequence_len:
+            # Adjust the range variables.
+            start = self.inference_current_sequence_len
+            self.inference_current_sequence_len += key_layer.size(0)
+            end = self.inference_current_sequence_len
+            # Copy key and values.
+            self.inference_key_memory[start:end, ...] = key_layer
+            self.inference_value_memory[start:end, ...] = value_layer
+            key_layer = self.inference_key_memory[:end, ...]
+            value_layer = self.inference_value_memory[:end, ...]
+            # Adjust attention mask
+            attention_mask = attention_mask[..., start:end, :end]
+
 
         # ===================================
         # Raw attention scores. [b, np, s, s]
@@ -270,22 +320,6 @@
         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)
 
-        # ==================================================
-        # Update attention mask for inference. [b, np, sq, sk]
-        # ==================================================
-
-        if get_key_value:
-            with torch.no_grad():
-                if layer_past is not None:
-                    attention_mask = attention_mask[
-                        ...,
-                        attention_scores.size(3) - 1,
-                        :attention_scores.size(3)].unsqueeze(2)
-                else:
-                    attention_mask = attention_mask[
-                        ...,
-                        :attention_scores.size(3),
-                        :attention_scores.size(3)]
-
         # ===========================
         # Attention probs and dropout
@@ -341,9 +375,6 @@
 
         output, bias = self.dense(context_layer)
 
-        if get_key_value:
-            output = [output, present]
-
         return output, bias
 
 
@@ -430,21 +461,21 @@ class ParallelTransformerLayer(MegatronModule):
                                output_layer_init_method)
 
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, enc_dec_attn_mask=None,
-                layer_past=None, get_key_value=False):
+                encoder_output=None,
+                enc_dec_attn_mask=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
         # hidden_states: [b, s, h]
 
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
         attention_output, attention_bias = \
-            self.self_attention(layernorm_output,
-                                attention_mask,
-                                layer_past=layer_past,
-                                get_key_value=get_key_value)
-
-        if get_key_value:
-            attention_output, presents = attention_output
+            self.self_attention(
+                layernorm_output,
+                attention_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=inference_max_sequence_len)
 
         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
@@ -514,9 +545,6 @@
                 residual,
                 self.hidden_dropout)
 
-        if get_key_value:
-            output = [output, presents]
-
         return output
 
 
@@ -659,18 +687,16 @@ class ParallelTransformer(MegatronModule):
         forward_step_func"""
         self.input_tensor = input_tensor
 
-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None,
+                enc_dec_attn_mask=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
 
         # Checks.
-        if layer_past is not None:
-            assert get_key_value, \
-                'for not None values in layer_past, ' \
-                'expected get_key_value to be set'
-        if get_key_value:
+        if inference_max_sequence_len:
             assert self.activations_checkpoint_method is None, \
-                'get_key_value does not work with ' \
-                'activation checkpointing'
+                'inference does not work with activation checkpointing'
 
         if self.pre_process:
             # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
@@ -693,22 +719,15 @@
                                                        encoder_output,
                                                        enc_dec_attn_mask)
         else:
-            if get_key_value:
-                presents = []
             for index in range(self.num_layers):
                 layer = self._get_layer(index)
-                past = None
-                if layer_past is not None:
-                    past = layer_past[index]
-                hidden_states = layer(hidden_states,
-                                      attention_mask,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask,
-                                      layer_past=past,
-                                      get_key_value=get_key_value)
-                if get_key_value:
-                    hidden_states, present = hidden_states
-                    presents.append(present)
+                hidden_states = layer(
+                    hidden_states,
+                    attention_mask,
+                    encoder_output=encoder_output,
+                    enc_dec_attn_mask=enc_dec_attn_mask,
+                    set_inference_key_value_memory=set_inference_key_value_memory,
+                    inference_max_sequence_len=inference_max_sequence_len)
 
         # Final layer norm.
         if self.post_process:
@@ -717,7 +736,5 @@
             output = self.final_layernorm(hidden_states)
         else:
             output = hidden_states
-        if get_key_value:
-            output = [output, presents]
 
         return output
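
Each ParallelAttention instance now owns an inference_key_memory / inference_value_memory pair, allocated once at the full inference_max_sequence_len and filled slice by slice as tokens arrive. Because the buffers are sized up front, their footprint is easy to bound; a back-of-the-envelope sketch with assumed model dimensions (none of these numbers come from this commit):

    # Per layer and per tensor-parallel rank: key + value buffers of shape
    # [inference_max_sequence_len, batch, heads_per_partition, head_dim].
    max_len, batch = 2048, 8
    heads_per_partition, head_dim = 32, 128   # assumed model shape
    num_layers, bytes_per_elem = 48, 2        # assumed depth, fp16 params_dtype
    per_layer = 2 * max_len * batch * heads_per_partition * head_dim * bytes_per_elem
    print(f"KV cache per rank: {num_layers * per_layer / 2**30:.1f} GiB")  # ~12 GiB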
+8 −1
@@ -58,6 +58,13 @@ class MegatronGenerate(Resource):
             if not isinstance(all_probs, bool):
                 return "all_probs must be a boolean value"
 
+        temperature = args.temperature
+        if "temperature" in request.get_json():
+            temperature = request.get_json()["temperature"]
+            if not isinstance(temperature, float) or not \
+               0.0 < temperature <= 100.0:
+                return "temperature must be a positive float less than or equal to 100.0"
+
         add_BOS = False
         if "add_BOS" in request.get_json():
             add_BOS = request.get_json()["add_BOS"]
@@ -66,7 +73,7 @@
 
         sem.acquire()  # Need to get lock to keep multiple threads from hitting code
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, add_BOS)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS)
         sem.release()
 
         if all_probs:
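
With the new validation block, a request can override the server-side default sampling temperature per call. An illustrative client using the Python requests package; the endpoint URL and the prompt field name are not visible in this hunk and are assumptions, only tokens_to_generate, all_probs, add_BOS, and temperature appear in the diff:

    import requests

    # "temperature" must be a JSON float in (0.0, 100.0]; an integer such
    # as 1 fails the isinstance(temperature, float) check and is rejected.
    response = requests.put("http://localhost:5000/generate", json={  # assumed URL
        "sentences": ["Megatron-LM is"],   # assumed prompt field name
        "tokens_to_generate": 32,
        "temperature": 0.7,
    })
    print(response.json())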
+34 −29
@@ -141,14 +141,15 @@ def receive_generate_info():
 
     return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs
 
-def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
                                                  tokens_to_generate,
-                                                 all_probs)
+                                                 all_probs,
+                                                 temperature=temperature)
     for tokens, lengths, output_logits, full_logits in batch_token_iterator:
         context_length += 1
 
@@ -177,7 +178,7 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits
 
-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, add_BOS=False):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
@@ -185,8 +186,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, add_B
     else:
         context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
 
-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
-
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
     if output is not None:
         decode_tokens, output_logits, full_logits = output
 
@@ -230,8 +230,8 @@ def switch(val1, val2, boolean):
 
 
 def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
-                 layer_past=None, get_key_value=None,
-                 forward_method_parallel_output=None):
+                 set_inference_key_value_memory=False,
+                 inference_max_sequence_len=None):
 
     # Hidden size changes when not using recompute, need to tell p2p_communicate
     # functions the correct size
@@ -246,26 +246,22 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
     unwrapped_model = unwrap_model(
         model, (torchDDP, LocalDDP, Float16Module))
     unwrapped_model.set_input_tensor(input_tensor)
-    output_tensor = model(tokens, position_ids, attention_mask,
-                          tokentype_ids=tokentype_ids,
-                          layer_past=layer_past,
-                          get_key_value=get_key_value,
-                          forward_method_parallel_output=forward_method_parallel_output)
-
-    if get_key_value:
-        output_tensor, layer_past = output_tensor
+    output_tensor = model(
+        tokens, position_ids, attention_mask,
+        tokentype_ids=tokentype_ids,
+        set_inference_key_value_memory=set_inference_key_value_memory,
+        inference_max_sequence_len=inference_max_sequence_len)
 
     send_forward(output_tensor)
 
     args.seq_length = orig_seq_length
-    if get_key_value:
-        return output_tensor, layer_past
-
     return output_tensor
 
 
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
-                          tokens_to_generate, all_probs=False, type_ids=None):
+                          tokens_to_generate, all_probs=False, type_ids=None, temperature=None):
     args = get_args()
     tokenizer = get_tokenizer()
 
@@ -282,7 +278,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
 
         counter = 0
 
-        layer_past = None
         batch_size = context_tokens.size(0)
         is_done = torch.zeros([batch_size]).byte().cuda()
         tokens = context_tokens
@@ -299,11 +294,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         while context_length < maxlen:
             types2use = None
             if counter == 0:
+                # Allocate memory for the entire context.
+                set_inference_key_value_memory = True
                 tokens2use = tokens[:, :context_length]
                 positions2use = position_ids[:, :context_length]
                 if type_ids is not None:
                     types2use = type_ids[:, :context_length]
             else:
+                # Set this to false so the memory is not reallocated.
+                set_inference_key_value_memory = False
                 tokens2use = tokens[:, context_length - 1].view(
                     batch_size, -1)
                 positions2use = position_ids[:, context_length - 1].view(
@@ -311,29 +310,35 @@
                 if type_ids is not None:
                     types2use = type_ids[:, context_length - 1].view(
                         batch_size, -1)
-            output, layer_past = forward_step(model, tokens2use,
-                                              positions2use,
-                                              attention_mask,
-                                              layer_past=layer_past,
-                                              get_key_value=True,
-                                              tokentype_ids=types2use,
-                                              forward_method_parallel_output=False)
+
+            output = forward_step(
+                model, tokens2use,
+                positions2use,
+                attention_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=maxlen,
+                tokentype_ids=types2use)
 
             if mpu.is_pipeline_last_stage():
                 assert output is not None
+                output = output.float()
                 logits = output[:, -1].view(batch_size, -1).contiguous()
 
             if mpu.is_pipeline_last_stage():
                 if args.greedy:
                     prev = torch.argmax(logits, dim=-1).view(-1)
                 else:
-                    logits = logits.float()
-                    logits /= args.temperature
+                    logits /= temperature
                     logits = top_k_logits(logits, top_k=args.top_k,
                                           top_p=args.top_p)
                     log_probs = F.softmax(logits, dim=-1)
                     prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                 started = context_lengths <= context_length
 
+                # Clamp the out of vocabulary tokens.
+                tokenizer = get_tokenizer()
+                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
+
                 new_tokens = switch(
                     tokens[:, context_length].view(-1), prev, started)
                 tokens[:, context_length] = new_tokens
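
The temperature that now reaches sample_sequence_batch divides the logits before the softmax, so values below 1.0 sharpen the sampling distribution and values above 1.0 flatten it. A small standalone illustration with toy logits (not model output):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([2.0, 1.0, 0.0])
    for temperature in (1.0, 0.5, 2.0):
        probs = F.softmax(logits / temperature, dim=-1)
        print(temperature, probs.tolist())
    # 1.0 -> approx. [0.665, 0.245, 0.090]
    # 0.5 -> approx. [0.867, 0.117, 0.016]
    # 2.0 -> approx. [0.506, 0.307, 0.186]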