Loading megatron/model/transformer.py +14 −10 Original line number Diff line number Diff line Loading @@ -172,14 +172,14 @@ class ParallelSelfAttention(MegatronModule): init_method=output_layer_init_method, skip_bias_add=True) def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_index): def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first): input_shape = mixed_layer.size(); if num_splits_first: """[s, b, num_splits * np * hn] -->(view) [s, b, num_splits, np, hn] -->(tranpose) [s, b, np, num_splits, hn] -->(view) [s, b, np * num_splits * hn] """ input_shape = mixed_layer.size(); if num_splits_index == 0: intermediate_shape = input_shape[:-1] +\ (num_splits, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) Loading @@ -187,7 +187,11 @@ class ParallelSelfAttention(MegatronModule): mixed_layer = mixed_layer.view(*intermediate_shape) mixed_layer = mixed_layer.transpose(-2, -3).contiguous() else: assert num_splits_index == 2 """[s, b, np * hn * num_splits] -->(view) [s, b, np, hn, num_splits] -->(tranpose) [s, b, np, num_splits, hn] -->(view) [s, b, np * num_splits * hn] """ intermediate_shape = input_shape[:-1] +\ (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, num_splits) Loading @@ -213,10 +217,10 @@ class ParallelSelfAttention(MegatronModule): if checkpoint_version is not None: if checkpoint_version == 0: # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)] mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 0) mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True) elif checkpoint_version == 1: # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)] mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 2) mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + \ Loading Loading
megatron/model/transformer.py +14 −10 Original line number Diff line number Diff line Loading @@ -172,14 +172,14 @@ class ParallelSelfAttention(MegatronModule): init_method=output_layer_init_method, skip_bias_add=True) def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_index): def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first): input_shape = mixed_layer.size(); if num_splits_first: """[s, b, num_splits * np * hn] -->(view) [s, b, num_splits, np, hn] -->(tranpose) [s, b, np, num_splits, hn] -->(view) [s, b, np * num_splits * hn] """ input_shape = mixed_layer.size(); if num_splits_index == 0: intermediate_shape = input_shape[:-1] +\ (num_splits, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) Loading @@ -187,7 +187,11 @@ class ParallelSelfAttention(MegatronModule): mixed_layer = mixed_layer.view(*intermediate_shape) mixed_layer = mixed_layer.transpose(-2, -3).contiguous() else: assert num_splits_index == 2 """[s, b, np * hn * num_splits] -->(view) [s, b, np, hn, num_splits] -->(tranpose) [s, b, np, num_splits, hn] -->(view) [s, b, np * num_splits * hn] """ intermediate_shape = input_shape[:-1] +\ (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, num_splits) Loading @@ -213,10 +217,10 @@ class ParallelSelfAttention(MegatronModule): if checkpoint_version is not None: if checkpoint_version == 0: # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)] mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 0) mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True) elif checkpoint_version == 1: # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)] mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 2) mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + \ Loading