Summary: this change replaces the layer_past / get_key_value incremental-decoding interface with key-value memory that is pre-allocated once inside each ParallelAttention module. Callers now pass set_inference_key_value_memory and inference_max_sequence_len instead of threading past key/value tensors through every forward call, and get_model() gains a wrap_with_ddp flag so inference scripts can build an unwrapped model.

megatron/model/gpt_model.py +6 −15

@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal

 def post_language_model_processing(lm_output, labels, logit_weights,
-                                   get_key_value,
-                                   parallel_output,
-                                   forward_method_parallel_output,
+                                   parallel_output,
                                    fp16_lm_cross_entropy):
-    if get_key_value:
-        lm_output, presents = lm_output

     # Output.
-    if forward_method_parallel_output is not None:
-        parallel_output = forward_method_parallel_output
     output = parallel_lm_logits(
         lm_output,
         logit_weights,
         parallel_output)

-    if get_key_value:
-        output = [output, presents]
-
     if labels is None:
         return output
     else:

@@ -90,23 +82,22 @@ class GPTModel(MegatronModule):

         self.language_model.set_input_tensor(input_tensor)

     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
+                tokentype_ids=None, set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):

         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
+            set_inference_key_value_memory=set_inference_key_value_memory,
+            inference_max_sequence_len=inference_max_sequence_len)

         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),
-                get_key_value,
                 self.parallel_output,
-                forward_method_parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
             return lm_output
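A rough usage sketch of the new GPTModel.forward() interface (not part of the diff; the tensor names and surrounding setup are assumed): the first call allocates the key-value memory and runs the full prompt, later calls feed only the newest token.

    # First call: allocate key-value memory for the whole generation and
    # run the prompt through it.
    logits = model(prompt_tokens, prompt_position_ids, attention_mask,
                   set_inference_key_value_memory=True,
                   inference_max_sequence_len=max_gen_len)

    # Later calls: only the newest token; the per-layer caches are reused
    # instead of being passed back in as layer_past was before.
    logits = model(last_token, last_position_id, attention_mask,
                   set_inference_key_value_memory=False,
                   inference_max_sequence_len=max_gen_len)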
megatron/model/language_model.py +16 −12

@@ -334,8 +334,10 @@ class TransformerLanguageModel(MegatronModule):

     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
-                get_key_value=False, pooling_sequence_index=0,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None,
+                pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

         # Embeddings.

@@ -348,10 +350,11 @@ class TransformerLanguageModel(MegatronModule):

         # encoder.
         if enc_hidden_states is None:
-            encoder_output = self.encoder(encoder_input,
+            encoder_output = self.encoder(
+                encoder_input,
                 enc_attn_mask,
-                layer_past=layer_past,
-                get_key_value=get_key_value)
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=inference_max_sequence_len)
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)

@@ -373,12 +376,13 @@ class TransformerLanguageModel(MegatronModule):

         dec_embedding_output = self.embedding(dec_input_ids,
                                               dec_position_ids)

         # decoder
-        decoder_output = self.decoder(dec_embedding_output,
+        decoder_output = self.decoder(
+            dec_embedding_output,
             dec_attn_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
             encoder_output=encoder_output,
-            enc_dec_attn_mask=enc_dec_attn_mask)
+            enc_dec_attn_mask=enc_dec_attn_mask,
+            set_inference_key_value_memory=set_inference_key_value_memory,
+            inference_max_sequence_len=inference_max_sequence_len)

         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output

megatron/model/transformer.py +87 −70

@@ -118,6 +118,7 @@ class ParallelAttention(MegatronModule):

         self.layer_number = max(1, layer_number)
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
+        self.params_dtype = args.params_dtype

         projection_size = args.kv_channels * args.num_attention_heads

@@ -178,10 +179,53 @@ class ParallelAttention(MegatronModule):

             init_method=output_layer_init_method,
             skip_bias_add=True)

-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None):
+        # Inference key-value memory
+        self.inference_key_memory = None
+        self.inference_value_memory = None
+        self.inference_current_sequence_len = 0
+
+    def _allocate_memory(self, inference_max_sequence_len, batch_size):
+        return torch.empty(
+            inference_max_sequence_len,
+            batch_size,
+            self.num_attention_heads_per_partition,
+            self.hidden_size_per_attention_head,
+            dtype=self.params_dtype,
+            device=torch.cuda.current_device())
+
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None, set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
         # hidden_states: [sq, b, h]

+        # =================================================
+        # Pre-allocate memory for key-values for inference.
+        # =================================================
+        if set_inference_key_value_memory:
+            assert inference_max_sequence_len and inference_max_sequence_len > 0
+            self.inference_key_memory = self._allocate_memory(
+                inference_max_sequence_len, hidden_states.size(1))
+            self.inference_value_memory = self._allocate_memory(
+                inference_max_sequence_len, hidden_states.size(1))
+            self.inference_current_sequence_len = 0
+        # Some consistency check.
+        if inference_max_sequence_len:
+            assert self.inference_current_sequence_len < \
+                self.inference_key_memory.size(0)
+            assert inference_max_sequence_len == \
+                self.inference_key_memory.size(0)
+        # This is added for safety. In case inference_max_sequence_len
+        # is not provided, make sure there is no potential memory left
+        # from previous inference.
+        if not inference_max_sequence_len:
+            self.inference_key_memory = None
+            self.inference_value_memory = None
+
         # =====================
         # Query, Key, and Value
         # =====================
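For intuition about what _allocate_memory reserves, a back-of-the-envelope sketch (the configuration numbers below are invented for illustration; only the buffer shape comes from the diff):

    # Per-layer inference KV cache implied by _allocate_memory():
    # [max_seq_len, batch, heads_per_partition, head_dim], one buffer each
    # for keys and values, in params_dtype.
    max_seq_len = 2048
    batch_size = 8
    heads_per_partition = 16      # num_attention_heads / tensor_parallel_size
    head_dim = 64                 # kv_channels
    bytes_per_element = 2         # fp16 / bf16 params_dtype

    elements = max_seq_len * batch_size * heads_per_partition * head_dim
    kv_bytes_per_layer = 2 * elements * bytes_per_element   # key + value
    print('%.0f MiB per transformer layer' % (kv_bytes_per_layer / 2**20))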
@@ -222,18 +266,24 @@ class ParallelAttention(MegatronModule):

                 self.hidden_size_per_attention_head)
             query_layer = query_layer.view(*new_tensor_shape)

-        # ==================================
-        # Adjust key and value for inference
-        # ==================================
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key_layer = torch.cat((past_key.type_as(key_layer),
-                                   key_layer), dim=0)
-            value_layer = torch.cat((past_value.type_as(value_layer),
-                                     value_layer), dim=0)
-        if get_key_value:
-            present = (key_layer, value_layer)
+        # ===================================================
+        # Adjust key, value, and attention mask for inference
+        # ===================================================
+
+        if inference_max_sequence_len:
+            # Adjust the range variables.
+            start = self.inference_current_sequence_len
+            self.inference_current_sequence_len += key_layer.size(0)
+            end = self.inference_current_sequence_len
+            # Copy key and values.
+            self.inference_key_memory[start:end, ...] = key_layer
+            self.inference_value_memory[start:end, ...] = value_layer
+            key_layer = self.inference_key_memory[:end, ...]
+            value_layer = self.inference_value_memory[:end, ...]
+            # Adjust attention mask
+            attention_mask = attention_mask[..., start:end, :end]

         # ===================================
         # Raw attention scores. [b, np, s, s]
         # ===================================

@@ -270,22 +320,6 @@ class ParallelAttention(MegatronModule):

         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)

-        # ==================================================
-        # Update attention mask for inference. [b, np, sq, sk]
-        # ==================================================
-
-        if get_key_value:
-            with torch.no_grad():
-                if layer_past is not None:
-                    attention_mask = attention_mask[
-                        ...,
-                        attention_scores.size(3) - 1,
-                        :attention_scores.size(3)].unsqueeze(2)
-                else:
-                    attention_mask = attention_mask[
-                        ...,
-                        :attention_scores.size(3),
-                        :attention_scores.size(3)]
-
         # ===========================
         # Attention probs and dropout
         # ===========================

@@ -341,9 +375,6 @@ class ParallelAttention(MegatronModule):

         output, bias = self.dense(context_layer)

-        if get_key_value:
-            output = [output, present]
-
         return output, bias
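The caching logic above is easier to see in isolation. A minimal stand-alone sketch of the same pattern in plain PyTorch (shapes and names are illustrative, not taken from the diff):

    import torch

    max_len, batch, heads, head_dim = 16, 2, 4, 8
    key_memory = torch.empty(max_len, batch, heads, head_dim)
    value_memory = torch.empty(max_len, batch, heads, head_dim)
    current_len = 0

    def append_and_slice(new_keys, new_values, attention_mask):
        # Append this step's keys/values in place, then attend over the
        # filled prefix only.
        global current_len
        start = current_len
        current_len += new_keys.size(0)
        end = current_len
        key_memory[start:end] = new_keys
        value_memory[start:end] = new_values
        return (key_memory[:end], value_memory[:end],
                attention_mask[..., start:end, :end])

During the prompt pass new_keys spans the whole prompt, so the mask slice keeps the usual causal triangle; during generation it is a single position and the slice collapses to one row, which is what the removed "update attention mask for inference" block used to compute separately.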
@@ -430,21 +461,21 @@ class ParallelTransformerLayer(MegatronModule):

             output_layer_init_method)

     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, enc_dec_attn_mask=None,
-                layer_past=None, get_key_value=False):
+                encoder_output=None, enc_dec_attn_mask=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
         # hidden_states: [b, s, h]

         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
         attention_output, attention_bias = \
-            self.self_attention(layernorm_output,
-                                attention_mask,
-                                layer_past=layer_past,
-                                get_key_value=get_key_value)
-
-        if get_key_value:
-            attention_output, presents = attention_output
+            self.self_attention(
+                layernorm_output,
+                attention_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=inference_max_sequence_len)

         # Residual connection.
         if self.apply_residual_connection_post_layernorm:

@@ -514,9 +545,6 @@ class ParallelTransformerLayer(MegatronModule):

                 residual,
                 self.hidden_dropout)

-        if get_key_value:
-            output = [output, presents]
-
         return output

@@ -659,18 +687,16 @@ class ParallelTransformer(MegatronModule):

            forward_step_func"""
         self.input_tensor = input_tensor

-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None,
-                enc_dec_attn_mask=None):
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None, enc_dec_attn_mask=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):

         # Checks.
-        if layer_past is not None:
-            assert get_key_value, \
-                'for not None values in layer_past, ' \
-                'expected get_key_value to be set'
-        if get_key_value:
+        if inference_max_sequence_len:
             assert self.activations_checkpoint_method is None, \
-                'get_key_value does not work with ' \
-                'activation checkpointing'
+                'inference does not work with activation checkpointing'

         if self.pre_process:
             # Data format change to avoid explicit tranposes : [b s h] --> [s b h].

@@ -693,22 +719,15 @@ class ParallelTransformer(MegatronModule):

                 encoder_output,
                 enc_dec_attn_mask)
         else:
-            if get_key_value:
-                presents = []
             for index in range(self.num_layers):
                 layer = self._get_layer(index)
-                past = None
-                if layer_past is not None:
-                    past = layer_past[index]
-                hidden_states = layer(hidden_states,
-                                      attention_mask,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask,
-                                      layer_past=past,
-                                      get_key_value=get_key_value)
-                if get_key_value:
-                    hidden_states, present = hidden_states
-                    presents.append(present)
+                hidden_states = layer(
+                    hidden_states,
+                    attention_mask,
+                    encoder_output=encoder_output,
+                    enc_dec_attn_mask=enc_dec_attn_mask,
+                    set_inference_key_value_memory=set_inference_key_value_memory,
+                    inference_max_sequence_len=inference_max_sequence_len)

         # Final layer norm.
         if self.post_process:

@@ -717,7 +736,5 @@ class ParallelTransformer(MegatronModule):

             output = self.final_layernorm(hidden_states)
         else:
             output = hidden_states
-        if get_key_value:
-            output = [output, presents]

         return output

megatron/text_generation_utils.py +26 −21

@@ -227,8 +227,8 @@ def switch(val1, val2, boolean):

 def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
-                 layer_past=None, get_key_value=None,
-                 forward_method_parallel_output=None):
+                 set_inference_key_value_memory=False,
+                 inference_max_sequence_len=None):

     # Hidden size changes when not using recompute, need to tell p2p_communicate
     # functions the correct size

@@ -243,20 +243,16 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,

         unwrapped_model = unwrap_model(
             model, (torchDDP, LocalDDP, Float16Module))
         unwrapped_model.set_input_tensor(input_tensor)
-        output_tensor = model(tokens, position_ids, attention_mask,
+        output_tensor = model(
+            tokens, position_ids, attention_mask,
             tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            forward_method_parallel_output=forward_method_parallel_output)
-
-        if get_key_value:
-            output_tensor, layer_past = output_tensor
+            set_inference_key_value_memory=set_inference_key_value_memory,
+            inference_max_sequence_len=inference_max_sequence_len)

         send_forward(output_tensor)

     args.seq_length = orig_seq_length
-    if get_key_value:
-        return output_tensor, layer_past
     return output_tensor

@@ -279,7 +275,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,

     counter = 0
-    layer_past = None

     batch_size = context_tokens.size(0)
     is_done = torch.zeros([batch_size]).byte().cuda()
     tokens = context_tokens

@@ -296,11 +291,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths,

         while context_length < maxlen:
             types2use = None
             if counter == 0:
+                # Allocate memory for the entire context.
+                set_inference_key_value_memory = True
                 tokens2use = tokens[:, :context_length]
                 positions2use = position_ids[:, :context_length]
                 if type_ids is not None:
                     types2use = type_ids[:, :context_length]
             else:
+                # Set this to false so the memory is not reallocated.
+                set_inference_key_value_memory = False
                 tokens2use = tokens[:, context_length - 1].view(
                     batch_size, -1)
                 positions2use = position_ids[:, context_length - 1].view(

@@ -308,18 +307,20 @@ def sample_sequence_batch(model, context_tokens, context_lengths,

                 if type_ids is not None:
                     types2use = type_ids[:, context_length - 1].view(
                         batch_size, -1)
-            output, layer_past = forward_step(model, tokens2use,
+            output = forward_step(
+                model, tokens2use,
                 positions2use,
                 attention_mask,
-                layer_past=layer_past,
-                get_key_value=True,
-                tokentype_ids=types2use,
-                forward_method_parallel_output=False)
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=maxlen,
+                tokentype_ids=types2use)

             if mpu.is_pipeline_last_stage():
                 assert output is not None
                 output = output.float()
                 logits = output[:, -1].view(batch_size, -1).contiguous()

             if mpu.is_pipeline_last_stage():
                 if args.greedy:
                     prev = torch.argmax(logits, dim=-1).view(-1)
                 else:

@@ -331,6 +332,10 @@ def sample_sequence_batch(model, context_tokens, context_lengths,

                     prev = torch.multinomial(log_probs,
                                              num_samples=1).view(-1)
                 started = context_lengths <= context_length
+
+                # Clamp the out of vocabulary tokens.
+                tokenizer = get_tokenizer()
+                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
+
                 new_tokens = switch(
                     tokens[:, context_length].view(-1), prev, started)
                 tokens[:, context_length] = new_tokens
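Putting the sample_sequence_batch changes together, the loop structure is roughly the following. This is a simplified single-GPU, greedy-only sketch; greedy_generate and the mask construction are not from the diff.

    import torch

    def greedy_generate(model, tokens, context_length, max_len):
        """First iteration: run the whole prompt and allocate the KV memory.
        Later iterations: run only the newest token and reuse the memory.
        `model` is assumed to expose the new GPTModel.forward() signature."""
        position_ids = torch.arange(max_len, device=tokens.device)
        position_ids = position_ids.unsqueeze(0).expand(tokens.size(0), -1)
        # Boolean mask, True above the diagonal (masked positions).
        attention_mask = torch.tril(
            torch.ones(1, 1, max_len, max_len, device=tokens.device)) < 0.5
        first = True
        while context_length < max_len:
            if first:
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
            else:
                tokens2use = tokens[:, context_length - 1:context_length]
                positions2use = position_ids[:, context_length - 1:context_length]
            logits = model(tokens2use, positions2use, attention_mask,
                           set_inference_key_value_memory=first,
                           inference_max_sequence_len=max_len)
            tokens[:, context_length] = torch.argmax(logits[:, -1], dim=-1)
            context_length += 1
            first = False
        return tokens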
megatron/training.py +18 −16

@@ -189,7 +189,7 @@ def update_train_iters(args):

     print_rank_0('setting training iterations to {}'.format(args.train_iters))

-def get_model(model_provider_func):
+def get_model(model_provider_func, wrap_with_ddp=True):
     """Build the model."""
     args = get_args()

@@ -243,22 +243,24 @@ def get_model(model_provider_func):

     if args.fp16 or args.bf16:
         model = [Float16Module(model_module, args) for model_module in model]

+    if wrap_with_ddp:
         if args.DDP_impl == 'torch':
             i = torch.cuda.current_device()
             model = [torchDDP(model_module, device_ids=[i], output_device=i,
                               process_group=mpu.get_data_parallel_group())
                      for model_module in model]
-        return model

-    if args.DDP_impl == 'local':
+        elif args.DDP_impl == 'local':
             model = [LocalDDP(model_module,
                               args.accumulate_allreduce_grads_in_fp32,
                               args.use_contiguous_buffers_in_local_ddp)
                      for model_module in model]
-        return model

-    raise NotImplementedError('Unknown DDP implementation specified: {}. '
-                              'Exiting.'.format(args.DDP_impl))
+        else:
+            raise NotImplementedError('Unknown DDP implementation specified: '
+                                      '{}. Exiting.'.format(args.DDP_impl))
+
+    return model

 def get_learning_rate_scheduler(optimizer):
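The new wrap_with_ddp flag is what lets inference entry points build a bare, non-DDP-wrapped model. A hedged sketch of how a text-generation script might use it (model_provider is whatever provider function the calling script already defines; it is not part of this diff):

    from megatron.training import get_model

    # Build the model without the DDP wrapper, since generation does not
    # need gradient all-reduce.
    model = get_model(model_provider, wrap_with_ddp=False)

    # get_model() still returns a list of model chunks; generation utilities
    # typically iterate over it or use model[0].
    assert isinstance(model, list) and len(model) >= 1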