Commit ee38e7f9 authored by Mohammad Shoeybi
Browse files

Fixed a deserialization issue when loading old checkpoints

parent 9993ea25
Loading
Loading
Loading
Loading
+13 −1
Original line number Diff line number Diff line
@@ -338,7 +338,19 @@ def load_checkpoint(model, optimizer, lr_scheduler, args):
            torch.distributed.get_rank(), checkpoint_name))

    # Load the checkpoint.
    try:
        sd = torch.load(checkpoint_name, map_location='cpu')
    except ModuleNotFoundError:
        # For backward compatibility.
        print_rank_0(' > deserializing using the old code structure ...')
        import sys
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16.loss_scaler']
        sd = torch.load(checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
    except:
        print_rank_0('could not load the checkpoint')
        exit()

    # Iterations.
    if args.finetune or release: