megatron/utils.py  (+13 −1)

@@ -338,7 +338,19 @@ def load_checkpoint(model, optimizer, lr_scheduler, args):
             torch.distributed.get_rank(), checkpoint_name))
 
     # Load the checkpoint.
-    sd = torch.load(checkpoint_name, map_location='cpu')
+    try:
+        sd = torch.load(checkpoint_name, map_location='cpu')
+    except ModuleNotFoundError:
+        # For backward compatibility.
+        print_rank_0(' > deserializing using the old code structure ...')
+        import sys
+        sys.modules['fp16.loss_scaler'] = sys.modules[
+            'megatron.fp16.loss_scaler']
+        sd = torch.load(checkpoint_name, map_location='cpu')
+        sys.modules.pop('fp16.loss_scaler', None)
+    except:
+        print_rank_0('could not load the checkpoint')
+        exit()
 
     # Iterations.
     if args.finetune or release:
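The change wraps torch.load in a try/except: when deserialization fails with ModuleNotFoundError because the checkpoint was pickled while the loss scaler still lived at the old top-level path (fp16.loss_scaler), that old path is temporarily aliased in sys.modules to megatron.fp16.loss_scaler, the load is retried, and the alias is dropped afterwards. Below is a minimal, self-contained sketch of that sys.modules aliasing pattern using plain pickle; the oldpkg/newpkg module names and the LossScaler stand-in class are invented for this illustration and are not the real Megatron modules.

import pickle
import sys
import types

# Pretend LossScaler was originally defined in a module named 'oldpkg.loss_scaler'.
old_module = types.ModuleType('oldpkg.loss_scaler')

class LossScaler:
    def __init__(self, scale):
        self.scale = scale

LossScaler.__module__ = 'oldpkg.loss_scaler'
old_module.LossScaler = LossScaler
sys.modules['oldpkg.loss_scaler'] = old_module

# A "checkpoint" written under the old layout: the pickle records the class
# by its module path, 'oldpkg.loss_scaler.LossScaler'.
payload = pickle.dumps(LossScaler(scale=4096.0))

# The code base is later reorganized: the class now lives at 'newpkg.loss_scaler'
# and the old module path no longer exists.
new_module = types.ModuleType('newpkg.loss_scaler')
new_module.LossScaler = LossScaler
sys.modules['newpkg.loss_scaler'] = new_module
del sys.modules['oldpkg.loss_scaler']

try:
    sd = pickle.loads(payload)          # fails: 'oldpkg.loss_scaler' is gone
except ModuleNotFoundError:
    # Backward compatibility, as in the diff above: alias the old module path
    # to the relocated module so the unpickler can resolve the class, then
    # drop the alias once loading is done.
    sys.modules['oldpkg.loss_scaler'] = sys.modules['newpkg.loss_scaler']
    sd = pickle.loads(payload)
    sys.modules.pop('oldpkg.loss_scaler', None)

print(sd.scale)  # 4096.0

The alias only needs to cover the unpickling call itself, which is why the diff removes it again right after the retried torch.load; leaving stale entries in sys.modules could otherwise shadow a future, legitimately importable module of the same name.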