Loading megatron/checkpointing.py +1 −1 Original line number Diff line number Diff line Loading @@ -101,7 +101,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): # Arguments, iteration, and model. state_dict = {} state_dict['args'] = args state_dict['checkpoint_version'] = 1.0 state_dict['checkpoint_version'] = 2.0 state_dict['iteration'] = iteration state_dict['model'] = model.state_dict_for_save_checkpoint() Loading megatron/model/transformer.py +22 −10 Original line number Diff line number Diff line Loading @@ -172,19 +172,28 @@ class ParallelSelfAttention(MegatronModule): init_method=output_layer_init_method, skip_bias_add=True) def _transpose_last_dim(self, mixed_layer, num_splits): def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_index): """[s, b, num_splits * np * hn] -->(view) [s, b, num_splits, np, hn] -->(tranpose) [s, b, np, num_splits, hn] -->(view) [s, b, np * num_splits * hn] """ input_shape = mixed_layer.size(); if num_splits_index == 0: intermediate_shape = input_shape[:-1] +\ (num_splits, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) mixed_layer = mixed_layer.view(*intermediate_shape) mixed_layer = mixed_layer.transpose(-2, -3).contiguous() else: assert num_splits_index == 2 intermediate_shape = input_shape[:-1] +\ (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, num_splits) mixed_layer = mixed_layer.view(*intermediate_shape) mixed_layer = mixed_layer.transpose(-1, -2).contiguous() mixed_layer = mixed_layer.view(*input_shape) return mixed_layer Loading @@ -201,10 +210,13 @@ class ParallelSelfAttention(MegatronModule): mixed_x_layer, _ = self.query_key_value(hidden_states) checkpoint_version = get_checkpoint_version() if checkpoint_version is not None and \ checkpoint_version == 0: if checkpoint_version is not None: if checkpoint_version == 0: # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)] mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3) mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 0) elif checkpoint_version == 1: # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)] mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 2) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + \ Loading Loading
megatron/checkpointing.py +1 −1 Original line number Diff line number Diff line Loading @@ -101,7 +101,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): # Arguments, iteration, and model. state_dict = {} state_dict['args'] = args state_dict['checkpoint_version'] = 1.0 state_dict['checkpoint_version'] = 2.0 state_dict['iteration'] = iteration state_dict['model'] = model.state_dict_for_save_checkpoint() Loading
megatron/model/transformer.py +22 −10 Original line number Diff line number Diff line Loading @@ -172,19 +172,28 @@ class ParallelSelfAttention(MegatronModule): init_method=output_layer_init_method, skip_bias_add=True) def _transpose_last_dim(self, mixed_layer, num_splits): def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_index): """[s, b, num_splits * np * hn] -->(view) [s, b, num_splits, np, hn] -->(tranpose) [s, b, np, num_splits, hn] -->(view) [s, b, np * num_splits * hn] """ input_shape = mixed_layer.size(); if num_splits_index == 0: intermediate_shape = input_shape[:-1] +\ (num_splits, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) mixed_layer = mixed_layer.view(*intermediate_shape) mixed_layer = mixed_layer.transpose(-2, -3).contiguous() else: assert num_splits_index == 2 intermediate_shape = input_shape[:-1] +\ (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, num_splits) mixed_layer = mixed_layer.view(*intermediate_shape) mixed_layer = mixed_layer.transpose(-1, -2).contiguous() mixed_layer = mixed_layer.view(*input_shape) return mixed_layer Loading @@ -201,10 +210,13 @@ class ParallelSelfAttention(MegatronModule): mixed_x_layer, _ = self.query_key_value(hidden_states) checkpoint_version = get_checkpoint_version() if checkpoint_version is not None and \ checkpoint_version == 0: if checkpoint_version is not None: if checkpoint_version == 0: # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)] mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3) mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 0) elif checkpoint_version == 1: # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)] mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, 2) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + \ Loading