megatron/checkpointing.py  +1 −1

@@ -101,7 +101,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):

     # Arguments, iteration, and model.
     state_dict = {}
     state_dict['args'] = args
-    state_dict['checkpoint_version'] = 1.0
+    state_dict['checkpoint_version'] = 2.0
     state_dict['iteration'] = iteration
     state_dict['model'] = model.state_dict_for_save_checkpoint()
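Note: checkpoints written before this change carry checkpoint_version 1.0, and the very oldest carry no version key at all; the attention code below uses that value to decide whether the fused QKV weights still need re-ordering. A minimal sketch of that dispatch, assuming the value is read back from the saved dict (the helper below is illustrative, not Megatron's API):

    def qkv_layout_for(state_dict):
        # Illustrative helper: map the stored checkpoint_version onto the
        # fused-QKV ordering it implies (np = heads per partition, hn = head dim).
        version = state_dict.get('checkpoint_version', 0)
        if version == 0:
            return '[s, b, (3 * np * hn)]'   # oldest checkpoints, no version key
        if version == 1.0:
            return '[s, b, (np * hn * 3)]'   # pre-2.0 checkpoints
        return '[s, b, (np * 3 * hn)]'       # 2.0 and later: already in the new order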
megatron/model/transformer.py  +41 −25

@@ -172,17 +172,30 @@ class ParallelSelfAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)

-    def _transpose_last_dim(self, mixed_layer):
-        """[s, b, 3 * hp] -->(view) [s, b, 3, hp]
-        -->(transpose) [s, b, hp, 3] -->(view) [s, b, 3 * hp] """
+    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
         input_shape = mixed_layer.size()
-        last_dim = input_shape[-1]
-        assert last_dim % 3 == 0, "expected QKV dimension"
-        last_dim_split = last_dim // 3
-        intermediate_shape = input_shape[:-1] +\
-            (3, last_dim_split)
-        mixed_layer = mixed_layer.view(*intermediate_shape)
-        mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
+        if num_splits_first:
+            """[s, b, num_splits * np * hn]
+            -->(view) [s, b, num_splits, np, hn]
+            -->(transpose) [s, b, np, num_splits, hn]
+            -->(view) [s, b, np * num_splits * hn] """
+            intermediate_shape = input_shape[:-1] +\
+                (num_splits, self.num_attention_heads_per_partition,
+                 self.hidden_size_per_attention_head)
+            mixed_layer = mixed_layer.view(*intermediate_shape)
+            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
+        else:
+            """[s, b, np * hn * num_splits]
+            -->(view) [s, b, np, hn, num_splits]
+            -->(transpose) [s, b, np, num_splits, hn]
+            -->(view) [s, b, np * num_splits * hn] """
+            intermediate_shape = input_shape[:-1] +\
+                (self.num_attention_heads_per_partition,
+                 self.hidden_size_per_attention_head, num_splits)
+            mixed_layer = mixed_layer.view(*intermediate_shape)
+            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
         mixed_layer = mixed_layer.view(*input_shape)

@@ -197,25 +210,28 @@ class ParallelSelfAttention(MegatronModule):
         # Query, Key, and Value
         # =====================

-        # Attention heads [sq, b, hp] --> [sq, b, hp * 3]
+        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
         mixed_x_layer, _ = self.query_key_value(hidden_states)

         checkpoint_version = get_checkpoint_version()
-        if checkpoint_version is not None and \
-           checkpoint_version == 0:
-            # [sq, b, 3 * hp] --> [sq, b, hp * 3]
-            mixed_x_layer = self._transpose_last_dim(mixed_x_layer)
+        if checkpoint_version is not None:
+            if checkpoint_version == 0:
+                # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
+                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
+            elif checkpoint_version == 1.0:
+                # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
+                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

-        # [sq, b, hp * 3] --> [sq, b, np, hn, 3]
+        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
             (self.num_attention_heads_per_partition,
-             self.hidden_size_per_attention_head, 3)
+             3 * self.hidden_size_per_attention_head)
         mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

-        # [sq, b, np, hn, 3] --> 3 [sq, b, np, hn]
-        query_layer = mixed_x_layer[:,:,:,:,0]
-        key_layer = mixed_x_layer[:,:,:,:,1]
-        value_layer = mixed_x_layer[:,:,:,:,2]
+        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+        (query_layer,
+         key_layer,
+         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

         # ==================================
         # Adjust key and value for inference
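Note: a standalone sketch (not part of the patch) that exercises the same view/transpose/view sequence _transpose_last_dim performs, using toy sizes and plain torch; here `heads` stands in for np and `hn` for the per-head dimension:

    import torch

    s, b, heads, hn, num_splits = 2, 1, 4, 8, 3
    x = torch.arange(s * b * num_splits * heads * hn, dtype=torch.float32)

    # checkpoint_version == 0 stored QKV as [s, b, (num_splits * np * hn)]:
    # split num_splits out first, then move it behind the head dimension.
    v0 = x.view(s, b, num_splits * heads * hn)
    fixed_v0 = (v0.view(s, b, num_splits, heads, hn)
                  .transpose(-2, -3).contiguous()
                  .view(s, b, heads * num_splits * hn))

    # checkpoint_version == 1.0 stored QKV as [s, b, (np * hn * num_splits)]:
    # split num_splits out last, then swap it with the per-head dimension.
    v1 = x.view(s, b, heads * hn * num_splits)
    fixed_v1 = (v1.view(s, b, heads, hn, num_splits)
                  .transpose(-1, -2).contiguous()
                  .view(s, b, heads * num_splits * hn))

    # Both results are in the new [s, b, (np * 3 * hn)] ordering, so the forward
    # pass can view them as [sq, b, np, 3 * hn] and split the last dimension into
    # Q, K, V (chunk plays the role of mpu.split_tensor_along_last_dim here).
    q, k, v = fixed_v0.view(s, b, heads, num_splits * hn).chunk(3, dim=-1)
    assert q.shape == (s, b, heads, hn)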