megatron/checkpointing.py  (+14 −8)

@@ -110,6 +110,11 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     # Only rank zero of the data parallel writes to the disk.
     if isinstance(model, torchDDP):
         model = model.module
+
+    if torch.distributed.get_rank() == 0:
+        print('saving checkpoint at iteration {:7d} to {}'.format(
+            iteration, args.save), flush=True)
+
     if mpu.get_data_parallel_rank() == 0:
 
         # Arguments, iteration, and model.
@@ -137,14 +142,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         # Save.
         checkpoint_name = get_checkpoint_name(args.save, iteration)
-        print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
-              format(torch.distributed.get_rank(), iteration, checkpoint_name))
         ensure_directory_exists(checkpoint_name)
         torch.save(state_dict, checkpoint_name)
-        print(' successfully saved {}'.format(checkpoint_name))
 
     # Wait so everyone is done (necessary)
     torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
+            iteration, args.save), flush=True)
 
     # And update the latest iteration
     if torch.distributed.get_rank() == 0:
         tracker_filename = get_checkpoint_tracker_filename(args.save)
@@ -192,9 +197,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
 
     # Checkpoint.
     checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
-    if mpu.get_data_parallel_rank() == 0:
-        print('global rank {} is loading checkpoint {}'.format(
-            torch.distributed.get_rank(), checkpoint_name))
+    if torch.distributed.get_rank() == 0:
+        print(' loading checkpoint from {} at iteration {}'.format(
+            args.load, iteration), flush=True)
 
     # Load the checkpoint.
     try:
@@ -276,8 +281,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
             sys.exit()
 
     torch.distributed.barrier()
-    if mpu.get_data_parallel_rank() == 0:
-        print(' successfully loaded {}'.format(checkpoint_name))
+    if torch.distributed.get_rank() == 0:
+        print(' successfully loaded checkpoint from {} at iteration {}'.format(
+            args.load, iteration), flush=True)
 
     return iteration
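The diff replaces per-rank "global rank N is saving/loading" messages with single messages printed only by global rank 0, placed after the barrier so success is reported only once everyone has finished. Below is a minimal sketch of that rank-0-only logging pattern in isolation, assuming a torch.distributed process group is already initialized; save_with_rank0_logging, state_dict, and checkpoint_path are hypothetical stand-ins for illustration, not Megatron's API, and the single-writer check is simplified compared to Megatron's data-parallel-rank-0 logic.

import torch
import torch.distributed as dist


def save_with_rank0_logging(state_dict, checkpoint_path, iteration):
    # Announce the save once, from global rank 0 only, so the message is not
    # repeated by every data-parallel / model-parallel rank.
    if dist.get_rank() == 0:
        print('saving checkpoint at iteration {:7d} to {}'.format(
            iteration, checkpoint_path), flush=True)

    # Simplified single-writer rule for this sketch; Megatron instead lets
    # data-parallel rank 0 of each model-parallel group write its own shard.
    if dist.get_rank() == 0:
        torch.save(state_dict, checkpoint_path)

    # Barrier so no rank moves on (e.g. to update a tracker file) before the
    # checkpoint is fully written.
    dist.barrier()

    # Report success once, from rank 0, only after all ranks passed the barrier.
    if dist.get_rank() == 0:
        print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
            iteration, checkpoint_path), flush=True)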