Loading megatron/checkpointing.py +1 −1 Original line number Diff line number Diff line Loading @@ -312,7 +312,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True print_rank_0(f"Invalid checkpoint version {checkpoint_version}.") sys.exit() param.data.copy_(fixed_param) if name.endswith(('.key_value.weight', '.key_value.bias'): if name.endswith(('.key_value.weight', '.key_value.bias')): if checkpoint_version == 0: fixed_param = _transpose_first_dim(param.data, 2, True, model) elif checkpoint_version == 1.0: Loading Loading
megatron/checkpointing.py +1 −1 Original line number Diff line number Diff line Loading @@ -312,7 +312,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True print_rank_0(f"Invalid checkpoint version {checkpoint_version}.") sys.exit() param.data.copy_(fixed_param) if name.endswith(('.key_value.weight', '.key_value.bias'): if name.endswith(('.key_value.weight', '.key_value.bias')): if checkpoint_version == 0: fixed_param = _transpose_first_dim(param.data, 2, True, model) elif checkpoint_version == 1.0: Loading