Commit ee76a501 authored by Jared Casper
Browse files

Merge branch 'v0_checkpoint_fixes' into 'main'

fixed compatibility with v0 checkpoints

See merge request ADLR/megatron-lm!268
parents 2ff004ac 26b49aab
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -59,7 +59,7 @@ We have provided pretrained [BERT-345M](https://ngc.nvidia.com/catalog/models/nv
Alternatively, you can directly download the checkpoints using:

<pre>
BERT-345M-uncased: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_cased/zip -O megatron_bert_345m_v0.1_uncased.zip
BERT-345M-uncased: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_uncased/zip -O megatron_bert_345m_v0.1_uncased.zip
BERT-345M-cased: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_cased/zip -O megatron_bert_345m_v0.1_cased.zip
GPT-345M: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip
</pre>
+5 −1
Original line number Diff line number Diff line
@@ -211,6 +211,9 @@ def fix_query_key_value_ordering(model, checkpoint_version):
    version is smaller than 2.0
    """
    if checkpoint_version < 2.0:
        if isinstance(model, list):
            assert len(model)==1
            model = model[0]
        for name, param in model.named_parameters():
            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
                if checkpoint_version == 0:
@@ -291,8 +294,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
        state_dict = torch.load(checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
        sys.modules.pop('megatron.fp16.loss_scaler', None)
    except BaseException:
    except BaseException as e:
        print_rank_0('could not load the checkpoint')
        print_rank_0(e)
        sys.exit()

    # set checkpoint version