Loading megatron/model/gpt_model.py +6 −15 Original line number Diff line number Diff line Loading @@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal def post_language_model_processing(lm_output, labels, logit_weights, get_key_value, parallel_output, forward_method_parallel_output, parallel_output, fp16_lm_cross_entropy): if get_key_value: lm_output, presents = lm_output # Output. if forward_method_parallel_output is not None: parallel_output = forward_method_parallel_output output = parallel_lm_logits( lm_output, logit_weights, parallel_output) if get_key_value: output = [output, presents] if labels is None: return output else: Loading Loading @@ -90,23 +82,22 @@ class GPTModel(MegatronModule): self.language_model.set_input_tensor(input_tensor) def forward(self, input_ids, position_ids, attention_mask, labels=None, tokentype_ids=None, layer_past=None, get_key_value=False, forward_method_parallel_output=None): tokentype_ids=None, set_inference_key_value_memory=False, inference_max_sequence_len=None): lm_output = self.language_model( input_ids, position_ids, attention_mask, layer_past=layer_past, get_key_value=get_key_value) set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len) if self.post_process: return post_language_model_processing( lm_output, labels, self.word_embeddings_weight(), get_key_value, self.parallel_output, forward_method_parallel_output, self.fp16_lm_cross_entropy) else: return lm_output Loading megatron/model/language_model.py +16 −12 Original line number Diff line number Diff line Loading @@ -334,8 +334,10 @@ class TransformerLanguageModel(MegatronModule): def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None, get_key_value=False, pooling_sequence_index=0, enc_dec_attn_mask=None, tokentype_ids=None, set_inference_key_value_memory=False, inference_max_sequence_len=None, pooling_sequence_index=0, enc_hidden_states=None, output_enc_hidden=False): # Embeddings. Loading @@ -348,10 +350,11 @@ class TransformerLanguageModel(MegatronModule): # encoder. if enc_hidden_states is None: encoder_output = self.encoder(encoder_input, encoder_output = self.encoder( encoder_input, enc_attn_mask, layer_past=layer_past, get_key_value=get_key_value) set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len) else: encoder_output = enc_hidden_states.to(encoder_input.dtype) Loading @@ -373,12 +376,13 @@ class TransformerLanguageModel(MegatronModule): dec_embedding_output = self.embedding(dec_input_ids, dec_position_ids) # decoder decoder_output = self.decoder(dec_embedding_output, decoder_output = self.decoder( dec_embedding_output, dec_attn_mask, layer_past=layer_past, get_key_value=get_key_value, encoder_output=encoder_output, enc_dec_attn_mask=enc_dec_attn_mask) enc_dec_attn_mask=enc_dec_attn_mask, set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len) if self.add_pooler and self.post_process: return decoder_output, encoder_output, pooled_output Loading megatron/model/transformer.py +87 −70 Original line number Diff line number Diff line Loading @@ -118,6 +118,7 @@ class ParallelAttention(MegatronModule): self.layer_number = max(1, layer_number) self.attention_type = attention_type self.attn_mask_type = attn_mask_type self.params_dtype = args.params_dtype projection_size = args.kv_channels * args.num_attention_heads Loading Loading @@ -178,10 +179,53 @@ class ParallelAttention(MegatronModule): init_method=output_layer_init_method, skip_bias_add=True) def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None): # Inference key-value memory self.inference_key_memory = None self.inference_value_memory = None self.inference_current_sequence_len = 0 def _allocate_memory(self, inference_max_sequence_len, batch_size): return torch.empty( inference_max_sequence_len, batch_size, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, dtype=self.params_dtype, device=torch.cuda.current_device()) def forward(self, hidden_states, attention_mask, encoder_output=None, set_inference_key_value_memory=False, inference_max_sequence_len=None): # hidden_states: [sq, b, h] # ================================================= # Pre-allocate memory for key-values for inference. # ================================================= if set_inference_key_value_memory: assert inference_max_sequence_len and inference_max_sequence_len > 0 self.inference_key_memory = self._allocate_memory( inference_max_sequence_len, hidden_states.size(1)) self.inference_value_memory = self._allocate_memory( inference_max_sequence_len, hidden_states.size(1)) self.inference_current_sequence_len = 0 # Some consistency check. if inference_max_sequence_len: assert self.inference_current_sequence_len < \ self.inference_key_memory.size(0) assert inference_max_sequence_len == \ self.inference_key_memory.size(0) # This is added for safety. In case inference_max_sequence_len # is not provided, make sure there is no potential memory left # from previous inference. if not inference_max_sequence_len: self.inference_key_memory = None self.inference_value_memory = None # ===================== # Query, Key, and Value # ===================== Loading Loading @@ -222,18 +266,24 @@ class ParallelAttention(MegatronModule): self.hidden_size_per_attention_head) query_layer = query_layer.view(*new_tensor_shape) # ================================== # Adjust key and value for inference # ================================== if layer_past is not None: past_key, past_value = layer_past key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0) if get_key_value: present = (key_layer, value_layer) # =================================================== # Adjust key, value, and attention mask for inference # =================================================== if inference_max_sequence_len: # Adjust the range variables. start = self.inference_current_sequence_len self.inference_current_sequence_len += key_layer.size(0) end = self.inference_current_sequence_len # Copy key and values. self.inference_key_memory[start:end, ...] = key_layer self.inference_value_memory[start:end, ...] = value_layer key_layer = self.inference_key_memory[:end, ...] value_layer = self.inference_value_memory[:end, ...] # Adjust attention mask attention_mask = attention_mask[..., start:end, :end] # =================================== # Raw attention scores. [b, np, s, s] Loading Loading @@ -270,22 +320,6 @@ class ParallelAttention(MegatronModule): # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) # ================================================== # Update attention mask for inference. [b, np, sq, sk] # ================================================== if get_key_value: with torch.no_grad(): if layer_past is not None: attention_mask = attention_mask[ ..., attention_scores.size(3) - 1, :attention_scores.size(3)].unsqueeze(2) else: attention_mask = attention_mask[ ..., :attention_scores.size(3), :attention_scores.size(3)] # =========================== # Attention probs and dropout Loading Loading @@ -341,9 +375,6 @@ class ParallelAttention(MegatronModule): output, bias = self.dense(context_layer) if get_key_value: output = [output, present] return output, bias Loading Loading @@ -430,21 +461,21 @@ class ParallelTransformerLayer(MegatronModule): output_layer_init_method) def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, layer_past=None, get_key_value=False): encoder_output=None, enc_dec_attn_mask=None, set_inference_key_value_memory=False, inference_max_sequence_len=None): # hidden_states: [b, s, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, attention_bias = \ self.self_attention(layernorm_output, self.self_attention( layernorm_output, attention_mask, layer_past=layer_past, get_key_value=get_key_value) if get_key_value: attention_output, presents = attention_output set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len) # Residual connection. if self.apply_residual_connection_post_layernorm: Loading Loading @@ -514,9 +545,6 @@ class ParallelTransformerLayer(MegatronModule): residual, self.hidden_dropout) if get_key_value: output = [output, presents] return output Loading Loading @@ -659,18 +687,16 @@ class ParallelTransformer(MegatronModule): forward_step_func""" self.input_tensor = input_tensor def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None, enc_dec_attn_mask=None): def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, set_inference_key_value_memory=False, inference_max_sequence_len=None): # Checks. if layer_past is not None: assert get_key_value, \ 'for not None values in layer_past, ' \ 'expected get_key_value to be set' if get_key_value: if inference_max_sequence_len: assert self.activations_checkpoint_method is None, \ 'get_key_value does not work with ' \ 'activation checkpointing' 'inference does not work with activation checkpointing' if self.pre_process: # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. Loading @@ -693,22 +719,15 @@ class ParallelTransformer(MegatronModule): encoder_output, enc_dec_attn_mask) else: if get_key_value: presents = [] for index in range(self.num_layers): layer = self._get_layer(index) past = None if layer_past is not None: past = layer_past[index] hidden_states = layer(hidden_states, hidden_states = layer( hidden_states, attention_mask, encoder_output=encoder_output, enc_dec_attn_mask=enc_dec_attn_mask, layer_past=past, get_key_value=get_key_value) if get_key_value: hidden_states, present = hidden_states presents.append(present) set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len) # Final layer norm. if self.post_process: Loading @@ -717,7 +736,5 @@ class ParallelTransformer(MegatronModule): output = self.final_layernorm(hidden_states) else: output = hidden_states if get_key_value: output = [output, presents] return output megatron/text_generation_server.py +8 −1 Original line number Diff line number Diff line Loading @@ -58,6 +58,13 @@ class MegatronGenerate(Resource): if not isinstance(all_probs, bool): return "all_probs must be a boolean value" temperature = args.temperature if "temperature" in request.get_json(): temperature = request.get_json()["temperature"] if not isinstance(temperature, float) or not \ 0.0 < temperature <= 100.0: return "temperature must be a positive float less than or equal to 100.0" add_BOS = False if "add_BOS" in request.get_json(): add_BOS = request.get_json()["add_BOS"] Loading @@ -66,7 +73,7 @@ class MegatronGenerate(Resource): sem.acquire() # Need to get lock to keep multiple threads from hitting code MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, add_BOS) resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS) sem.release() if all_probs: Loading megatron/text_generation_utils.py +34 −29 Original line number Diff line number Diff line Loading @@ -141,14 +141,15 @@ def receive_generate_info(): return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs): def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature): context_length = context_length_tensor.min().item() tokens, attention_mask, position_ids = get_batch(context_tokens_tensor) batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, attention_mask, position_ids, tokens_to_generate, all_probs) all_probs, temperature=temperature) for tokens, lengths, output_logits, full_logits in batch_token_iterator: context_length += 1 Loading Loading @@ -177,7 +178,7 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_ if tokens is not None: return tokens[:, :context_length], output_logits, full_logits def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, add_BOS=False): def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False): model.eval() if torch.distributed.get_rank() == 0: context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS) Loading @@ -185,8 +186,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, add_B else: context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info() output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs) output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature) if output is not None: decode_tokens, output_logits, full_logits = output Loading Loading @@ -230,8 +230,8 @@ def switch(val1, val2, boolean): def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids, layer_past=None, get_key_value=None, forward_method_parallel_output=None): set_inference_key_value_memory=False, inference_max_sequence_len=None): # Hidden size changes when not using recompute, need to tell p2p_communicate # functions the correct size Loading @@ -246,26 +246,22 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids, unwrapped_model = unwrap_model( model, (torchDDP, LocalDDP, Float16Module)) unwrapped_model.set_input_tensor(input_tensor) output_tensor = model(tokens, position_ids, attention_mask, output_tensor = model( tokens, position_ids, attention_mask, tokentype_ids=tokentype_ids, layer_past=layer_past, get_key_value=get_key_value, forward_method_parallel_output=forward_method_parallel_output) if get_key_value: output_tensor, layer_past = output_tensor set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len) send_forward(output_tensor) args.seq_length = orig_seq_length if get_key_value: return output_tensor, layer_past return output_tensor def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokens_to_generate, all_probs=False, type_ids=None): tokens_to_generate, all_probs=False, type_ids=None, temperature=None): args = get_args() tokenizer = get_tokenizer() Loading @@ -282,7 +278,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths, counter = 0 layer_past = None batch_size = context_tokens.size(0) is_done = torch.zeros([batch_size]).byte().cuda() tokens = context_tokens Loading @@ -299,11 +294,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths, while context_length < maxlen: types2use = None if counter == 0: # Allocate memory for the entire context. set_inference_key_value_memory = True tokens2use = tokens[:, :context_length] positions2use = position_ids[:, :context_length] if type_ids is not None: types2use = type_ids[:, :context_length] else: # Set this to false so the memory is not reallocated. set_inference_key_value_memory = False tokens2use = tokens[:, context_length - 1].view( batch_size, -1) positions2use = position_ids[:, context_length - 1].view( Loading @@ -311,29 +310,35 @@ def sample_sequence_batch(model, context_tokens, context_lengths, if type_ids is not None: types2use = type_ids[:, context_length - 1].view( batch_size, -1) output, layer_past = forward_step(model, tokens2use, output = forward_step( model, tokens2use, positions2use, attention_mask, layer_past=layer_past, get_key_value=True, tokentype_ids=types2use, forward_method_parallel_output=False) set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=maxlen, tokentype_ids=types2use) if mpu.is_pipeline_last_stage(): assert output is not None output = output.float() logits = output[:, -1].view(batch_size, -1).contiguous() if mpu.is_pipeline_last_stage(): if args.greedy: prev = torch.argmax(logits, dim=-1).view(-1) else: logits = logits.float() logits /= args.temperature logits /= temperature logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p) log_probs = F.softmax(logits, dim=-1) prev = torch.multinomial(log_probs, num_samples=1).view(-1) started = context_lengths <= context_length # Clamp the out of vocabulary tokens. tokenizer = get_tokenizer() prev = torch.clamp(prev, max=tokenizer.vocab_size - 1) new_tokens = switch( tokens[:, context_length].view(-1), prev, started) tokens[:, context_length] = new_tokens Loading Loading
megatron/model/gpt_model.py +6 −15 Original line number Diff line number Diff line Loading @@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal def post_language_model_processing(lm_output, labels, logit_weights, get_key_value, parallel_output, forward_method_parallel_output, parallel_output, fp16_lm_cross_entropy): if get_key_value: lm_output, presents = lm_output # Output. if forward_method_parallel_output is not None: parallel_output = forward_method_parallel_output output = parallel_lm_logits( lm_output, logit_weights, parallel_output) if get_key_value: output = [output, presents] if labels is None: return output else: Loading Loading @@ -90,23 +82,22 @@ class GPTModel(MegatronModule): self.language_model.set_input_tensor(input_tensor) def forward(self, input_ids, position_ids, attention_mask, labels=None, tokentype_ids=None, layer_past=None, get_key_value=False, forward_method_parallel_output=None): tokentype_ids=None, set_inference_key_value_memory=False, inference_max_sequence_len=None): lm_output = self.language_model( input_ids, position_ids, attention_mask, layer_past=layer_past, get_key_value=get_key_value) set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len) if self.post_process: return post_language_model_processing( lm_output, labels, self.word_embeddings_weight(), get_key_value, self.parallel_output, forward_method_parallel_output, self.fp16_lm_cross_entropy) else: return lm_output Loading
megatron/model/language_model.py +16 −12 Original line number Diff line number Diff line Loading @@ -334,8 +334,10 @@ class TransformerLanguageModel(MegatronModule): def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None, get_key_value=False, pooling_sequence_index=0, enc_dec_attn_mask=None, tokentype_ids=None, set_inference_key_value_memory=False, inference_max_sequence_len=None, pooling_sequence_index=0, enc_hidden_states=None, output_enc_hidden=False): # Embeddings. Loading @@ -348,10 +350,11 @@ class TransformerLanguageModel(MegatronModule): # encoder. if enc_hidden_states is None: encoder_output = self.encoder(encoder_input, encoder_output = self.encoder( encoder_input, enc_attn_mask, layer_past=layer_past, get_key_value=get_key_value) set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len) else: encoder_output = enc_hidden_states.to(encoder_input.dtype) Loading @@ -373,12 +376,13 @@ class TransformerLanguageModel(MegatronModule): dec_embedding_output = self.embedding(dec_input_ids, dec_position_ids) # decoder decoder_output = self.decoder(dec_embedding_output, decoder_output = self.decoder( dec_embedding_output, dec_attn_mask, layer_past=layer_past, get_key_value=get_key_value, encoder_output=encoder_output, enc_dec_attn_mask=enc_dec_attn_mask) enc_dec_attn_mask=enc_dec_attn_mask, set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len) if self.add_pooler and self.post_process: return decoder_output, encoder_output, pooled_output Loading
megatron/model/transformer.py +87 −70 Original line number Diff line number Diff line Loading @@ -118,6 +118,7 @@ class ParallelAttention(MegatronModule): self.layer_number = max(1, layer_number) self.attention_type = attention_type self.attn_mask_type = attn_mask_type self.params_dtype = args.params_dtype projection_size = args.kv_channels * args.num_attention_heads Loading Loading @@ -178,10 +179,53 @@ class ParallelAttention(MegatronModule): init_method=output_layer_init_method, skip_bias_add=True) def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None): # Inference key-value memory self.inference_key_memory = None self.inference_value_memory = None self.inference_current_sequence_len = 0 def _allocate_memory(self, inference_max_sequence_len, batch_size): return torch.empty( inference_max_sequence_len, batch_size, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, dtype=self.params_dtype, device=torch.cuda.current_device()) def forward(self, hidden_states, attention_mask, encoder_output=None, set_inference_key_value_memory=False, inference_max_sequence_len=None): # hidden_states: [sq, b, h] # ================================================= # Pre-allocate memory for key-values for inference. # ================================================= if set_inference_key_value_memory: assert inference_max_sequence_len and inference_max_sequence_len > 0 self.inference_key_memory = self._allocate_memory( inference_max_sequence_len, hidden_states.size(1)) self.inference_value_memory = self._allocate_memory( inference_max_sequence_len, hidden_states.size(1)) self.inference_current_sequence_len = 0 # Some consistency check. if inference_max_sequence_len: assert self.inference_current_sequence_len < \ self.inference_key_memory.size(0) assert inference_max_sequence_len == \ self.inference_key_memory.size(0) # This is added for safety. In case inference_max_sequence_len # is not provided, make sure there is no potential memory left # from previous inference. if not inference_max_sequence_len: self.inference_key_memory = None self.inference_value_memory = None # ===================== # Query, Key, and Value # ===================== Loading Loading @@ -222,18 +266,24 @@ class ParallelAttention(MegatronModule): self.hidden_size_per_attention_head) query_layer = query_layer.view(*new_tensor_shape) # ================================== # Adjust key and value for inference # ================================== if layer_past is not None: past_key, past_value = layer_past key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0) if get_key_value: present = (key_layer, value_layer) # =================================================== # Adjust key, value, and attention mask for inference # =================================================== if inference_max_sequence_len: # Adjust the range variables. start = self.inference_current_sequence_len self.inference_current_sequence_len += key_layer.size(0) end = self.inference_current_sequence_len # Copy key and values. self.inference_key_memory[start:end, ...] = key_layer self.inference_value_memory[start:end, ...] = value_layer key_layer = self.inference_key_memory[:end, ...] value_layer = self.inference_value_memory[:end, ...] # Adjust attention mask attention_mask = attention_mask[..., start:end, :end] # =================================== # Raw attention scores. [b, np, s, s] Loading Loading @@ -270,22 +320,6 @@ class ParallelAttention(MegatronModule): # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) # ================================================== # Update attention mask for inference. [b, np, sq, sk] # ================================================== if get_key_value: with torch.no_grad(): if layer_past is not None: attention_mask = attention_mask[ ..., attention_scores.size(3) - 1, :attention_scores.size(3)].unsqueeze(2) else: attention_mask = attention_mask[ ..., :attention_scores.size(3), :attention_scores.size(3)] # =========================== # Attention probs and dropout Loading Loading @@ -341,9 +375,6 @@ class ParallelAttention(MegatronModule): output, bias = self.dense(context_layer) if get_key_value: output = [output, present] return output, bias Loading Loading @@ -430,21 +461,21 @@ class ParallelTransformerLayer(MegatronModule): output_layer_init_method) def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, layer_past=None, get_key_value=False): encoder_output=None, enc_dec_attn_mask=None, set_inference_key_value_memory=False, inference_max_sequence_len=None): # hidden_states: [b, s, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, attention_bias = \ self.self_attention(layernorm_output, self.self_attention( layernorm_output, attention_mask, layer_past=layer_past, get_key_value=get_key_value) if get_key_value: attention_output, presents = attention_output set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len) # Residual connection. if self.apply_residual_connection_post_layernorm: Loading Loading @@ -514,9 +545,6 @@ class ParallelTransformerLayer(MegatronModule): residual, self.hidden_dropout) if get_key_value: output = [output, presents] return output Loading Loading @@ -659,18 +687,16 @@ class ParallelTransformer(MegatronModule): forward_step_func""" self.input_tensor = input_tensor def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None, enc_dec_attn_mask=None): def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, set_inference_key_value_memory=False, inference_max_sequence_len=None): # Checks. if layer_past is not None: assert get_key_value, \ 'for not None values in layer_past, ' \ 'expected get_key_value to be set' if get_key_value: if inference_max_sequence_len: assert self.activations_checkpoint_method is None, \ 'get_key_value does not work with ' \ 'activation checkpointing' 'inference does not work with activation checkpointing' if self.pre_process: # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. Loading @@ -693,22 +719,15 @@ class ParallelTransformer(MegatronModule): encoder_output, enc_dec_attn_mask) else: if get_key_value: presents = [] for index in range(self.num_layers): layer = self._get_layer(index) past = None if layer_past is not None: past = layer_past[index] hidden_states = layer(hidden_states, hidden_states = layer( hidden_states, attention_mask, encoder_output=encoder_output, enc_dec_attn_mask=enc_dec_attn_mask, layer_past=past, get_key_value=get_key_value) if get_key_value: hidden_states, present = hidden_states presents.append(present) set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len) # Final layer norm. if self.post_process: Loading @@ -717,7 +736,5 @@ class ParallelTransformer(MegatronModule): output = self.final_layernorm(hidden_states) else: output = hidden_states if get_key_value: output = [output, presents] return output
megatron/text_generation_server.py +8 −1 Original line number Diff line number Diff line Loading @@ -58,6 +58,13 @@ class MegatronGenerate(Resource): if not isinstance(all_probs, bool): return "all_probs must be a boolean value" temperature = args.temperature if "temperature" in request.get_json(): temperature = request.get_json()["temperature"] if not isinstance(temperature, float) or not \ 0.0 < temperature <= 100.0: return "temperature must be a positive float less than or equal to 100.0" add_BOS = False if "add_BOS" in request.get_json(): add_BOS = request.get_json()["add_BOS"] Loading @@ -66,7 +73,7 @@ class MegatronGenerate(Resource): sem.acquire() # Need to get lock to keep multiple threads from hitting code MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, add_BOS) resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS) sem.release() if all_probs: Loading
megatron/text_generation_utils.py +34 −29 Original line number Diff line number Diff line Loading @@ -141,14 +141,15 @@ def receive_generate_info(): return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs): def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature): context_length = context_length_tensor.min().item() tokens, attention_mask, position_ids = get_batch(context_tokens_tensor) batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, attention_mask, position_ids, tokens_to_generate, all_probs) all_probs, temperature=temperature) for tokens, lengths, output_logits, full_logits in batch_token_iterator: context_length += 1 Loading Loading @@ -177,7 +178,7 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_ if tokens is not None: return tokens[:, :context_length], output_logits, full_logits def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, add_BOS=False): def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False): model.eval() if torch.distributed.get_rank() == 0: context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS) Loading @@ -185,8 +186,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, add_B else: context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info() output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs) output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature) if output is not None: decode_tokens, output_logits, full_logits = output Loading Loading @@ -230,8 +230,8 @@ def switch(val1, val2, boolean): def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids, layer_past=None, get_key_value=None, forward_method_parallel_output=None): set_inference_key_value_memory=False, inference_max_sequence_len=None): # Hidden size changes when not using recompute, need to tell p2p_communicate # functions the correct size Loading @@ -246,26 +246,22 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids, unwrapped_model = unwrap_model( model, (torchDDP, LocalDDP, Float16Module)) unwrapped_model.set_input_tensor(input_tensor) output_tensor = model(tokens, position_ids, attention_mask, output_tensor = model( tokens, position_ids, attention_mask, tokentype_ids=tokentype_ids, layer_past=layer_past, get_key_value=get_key_value, forward_method_parallel_output=forward_method_parallel_output) if get_key_value: output_tensor, layer_past = output_tensor set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len) send_forward(output_tensor) args.seq_length = orig_seq_length if get_key_value: return output_tensor, layer_past return output_tensor def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokens_to_generate, all_probs=False, type_ids=None): tokens_to_generate, all_probs=False, type_ids=None, temperature=None): args = get_args() tokenizer = get_tokenizer() Loading @@ -282,7 +278,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths, counter = 0 layer_past = None batch_size = context_tokens.size(0) is_done = torch.zeros([batch_size]).byte().cuda() tokens = context_tokens Loading @@ -299,11 +294,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths, while context_length < maxlen: types2use = None if counter == 0: # Allocate memory for the entire context. set_inference_key_value_memory = True tokens2use = tokens[:, :context_length] positions2use = position_ids[:, :context_length] if type_ids is not None: types2use = type_ids[:, :context_length] else: # Set this to false so the memory is not reallocated. set_inference_key_value_memory = False tokens2use = tokens[:, context_length - 1].view( batch_size, -1) positions2use = position_ids[:, context_length - 1].view( Loading @@ -311,29 +310,35 @@ def sample_sequence_batch(model, context_tokens, context_lengths, if type_ids is not None: types2use = type_ids[:, context_length - 1].view( batch_size, -1) output, layer_past = forward_step(model, tokens2use, output = forward_step( model, tokens2use, positions2use, attention_mask, layer_past=layer_past, get_key_value=True, tokentype_ids=types2use, forward_method_parallel_output=False) set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=maxlen, tokentype_ids=types2use) if mpu.is_pipeline_last_stage(): assert output is not None output = output.float() logits = output[:, -1].view(batch_size, -1).contiguous() if mpu.is_pipeline_last_stage(): if args.greedy: prev = torch.argmax(logits, dim=-1).view(-1) else: logits = logits.float() logits /= args.temperature logits /= temperature logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p) log_probs = F.softmax(logits, dim=-1) prev = torch.multinomial(log_probs, num_samples=1).view(-1) started = context_lengths <= context_length # Clamp the out of vocabulary tokens. tokenizer = get_tokenizer() prev = torch.clamp(prev, max=tokenizer.vocab_size - 1) new_tokens = switch( tokens[:, context_length].view(-1), prev, started) tokens[:, context_length] = new_tokens Loading