Commit 45504541 authored by mohammad

Merge branch 'blendable_dataset' into refactor_learning_rate

parents ff12df6b 98989693
megatron/arguments.py +9 −4
@@ -136,14 +136,16 @@ def parse_args(extra_args_provider=None, defaults={},
 def _print_args(args):
     """Print arguments."""
     if args.rank == 0:
-        print('-------------------- arguments --------------------', flush=True)
+        print('------------------------ arguments ------------------------',
+              flush=True)
         str_list = []
         for arg in vars(args):
-            dots = '.' * (32 - len(arg))
+            dots = '.' * (48 - len(arg))
             str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
         for arg in sorted(str_list, key=lambda x: x.lower()):
             print(arg, flush=True)
-        print('---------------- end of arguments ----------------', flush=True)
+        print('-------------------- end of arguments ---------------------',
+              flush=True)
 
 
 def _check_arg_is_not_none(args, arg):
@@ -401,7 +403,10 @@ def _add_data_args(parser):
     group = parser.add_argument_group(title='data and dataloader')
 
     group.add_argument('--data-path', nargs='*', default=None,
-                       help='Path to combined dataset to split.')
+                       help='Path to the training dataset. Accepted format: '
+                       '1) a single data path, 2) multiple datasets in the '
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
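
Note on the new --data-path format: with nargs='*', argparse hands the flag's value over as a flat token list, so the weighted form simply alternates weights and paths. A minimal sketch of how such a list could be split into weights and dataset prefixes (the helper name split_data_path and the normalization step are illustrative, not part of this commit):

    def split_data_path(data_path):
        """Split a --data-path token list into (weights, prefixes).

        One entry means a single unweighted dataset; otherwise the
        list alternates: weight1 path1 weight2 path2 ...
        """
        if len(data_path) == 1:
            return [1.0], list(data_path)
        assert len(data_path) % 2 == 0, 'expected weight/path pairs'
        weights = [float(w) for w in data_path[0::2]]
        prefixes = list(data_path[1::2])
        total = sum(weights)
        # Normalize so the weights sum to one.
        return [w / total for w in weights], prefixes

    # e.g. --data-path 0.3 /data/books_text 0.7 /data/web_text
    print(split_data_path(['0.3', '/data/books_text', '0.7', '/data/web_text']))
    # ([0.3, 0.7], ['/data/books_text', '/data/web_text'])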
megatron/checkpointing.py +5 −13
@@ -89,8 +89,7 @@ def get_checkpoint_tracker_filename(checkpoints_path):
     return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
 
 
-def save_checkpoint(iteration, model, optimizer, lr_scheduler,
-                    consumed_train_samples=None, consumed_valid_samples=None):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint."""
     args = get_args()
 
@@ -104,10 +103,6 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler,
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 2.0
         state_dict['iteration'] = iteration
-        if consumed_train_samples:
-            state_dict['consumed_train_samples'] = consumed_train_samples
-        if consumed_valid_samples:
-            state_dict['consumed_valid_samples'] = consumed_valid_samples
         state_dict['model'] = model.state_dict_for_save_checkpoint()
 
         # Optimizer stuff.
@@ -219,17 +214,14 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                                  checkpoint_name))
                 sys.exit()
 
-    if 'consumed_train_samples' in state_dict:
-        args.consumed_train_samples = state_dict['consumed_train_samples']
-    if 'consumed_valid_samples' in state_dict:
-        args.consumed_valid_samples = state_dict['consumed_valid_samples']
-
-    # Check arguments.
+    # Check arguments.
+    assert args.consumed_train_samples == 0
+    assert args.consumed_valid_samples == 0
     if 'args' in state_dict:
         checkpoint_args = state_dict['args']
         check_checkpoint_args(checkpoint_args)
+        args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0)
+        args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0)
     else:
         print_rank_0('could not find arguments in the checkpoint ...')

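Why the explicit kwargs could be dropped on both sides: save_checkpoint already stores the complete args object under state_dict['args'], so consumed_train_samples and consumed_valid_samples travel inside it rather than as separate top-level checkpoint keys, and load_checkpoint now recovers them from there. A minimal in-memory sketch of that resume path, with a plain dict standing in for the result of torch.load (the example values are made up):

    from argparse import Namespace

    # Stand-in for a loaded checkpoint: the full args Namespace is saved,
    # so the sample counters ride along with it.
    state_dict = {
        'checkpoint_version': 2.0,
        'iteration': 1000,
        'args': Namespace(consumed_train_samples=12345,
                          consumed_valid_samples=678),
    }

    # What load_checkpoint now does on resume:
    args = Namespace(consumed_train_samples=0, consumed_valid_samples=0)
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    checkpoint_args = state_dict['args']
    # getattr with a default keeps checkpoints that predate these fields loadable.
    args.consumed_train_samples = getattr(checkpoint_args,
                                          'consumed_train_samples', 0)
    args.consumed_valid_samples = getattr(checkpoint_args,
                                          'consumed_valid_samples', 0)
    print(args.consumed_train_samples, args.consumed_valid_samples)  # 12345 678
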
megatron/data/helpers.cpp +2 −2
@@ -60,7 +60,7 @@ void build_blending_indices(py::array_t<uint8_t>& dataset_index,
   for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
 
     // Determine where the max error in sampling is happening.
-    double sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
+    auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
     int64_t max_error_index = 0;
     double max_error = weights_ptr[0] * sample_idx_double -
       static_cast<double>(current_samples[0]);
@@ -86,7 +86,7 @@ void build_blending_indices(py::array_t<uint8_t>& dataset_index,
   if (verbose) {
     std::cout << " > sample ratios:" << std::endl;
     for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
-      double ratio = static_cast<double>(current_samples[dataset_idx]) /
+      auto ratio = static_cast<double>(current_samples[dataset_idx]) /
 	static_cast<double>(size);
       std::cout << "   dataset " << dataset_idx << ", input: " <<
 	weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
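
For reference, build_blending_indices implements a greedy quota scheme: at step n, each dataset's sampling error is weight_i * n - drawn_i, and the next sample is taken from the dataset whose error is largest, so the realized ratios track the requested weights as closely as possible. A pure-Python sketch of the same loop (equivalent logic only, not the pybind11 binding; the dataset_sample_index output and the clamping of n to 1.0 are assumed from the surrounding code, which the diff shows only in part):

    import numpy as np

    def build_blending_indices_py(weights, size):
        """Greedy blending: pick the dataset whose drawn-sample count
        lags its target share (weight * step) by the largest margin."""
        num_datasets = len(weights)
        dataset_index = np.zeros(size, dtype=np.uint8)
        dataset_sample_index = np.zeros(size, dtype=np.int64)
        current_samples = np.zeros(num_datasets, dtype=np.int64)
        for sample_idx in range(size):
            sample_idx_double = max(float(sample_idx), 1.0)
            errors = weights * sample_idx_double - current_samples
            max_error_index = int(np.argmax(errors))
            dataset_index[sample_idx] = max_error_index
            dataset_sample_index[sample_idx] = current_samples[max_error_index]
            current_samples[max_error_index] += 1
        return dataset_index, dataset_sample_index

    # weights [0.3, 0.7] over 10 samples draws 3 from dataset 0 and 7 from 1.
    idx, _ = build_blending_indices_py(np.array([0.3, 0.7]), 10)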
megatron/training.py +2 −6
@@ -104,9 +104,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                                    iteration, False)
 
     if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler,
-                        consumed_train_samples=args.consumed_train_samples,
-                        consumed_valid_samples=args.consumed_valid_samples)
+        save_checkpoint(iteration, model, optimizer, lr_scheduler)
 
     if args.do_test:
         # Run on test data.
@@ -438,9 +436,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         # Checkpointing
         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler,
-                            consumed_train_samples=args.consumed_train_samples,
-                            consumed_valid_samples=args.consumed_valid_samples)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)
 
         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \