Commit ae82905a authored by Laanait, Nouamane

Tweaks to YNet unsupervised inner loop

parent 67c84dc1
+4 −1
@@ -132,7 +132,10 @@ def get_YNet_constraint(n_net, hyper_params, params, psi_out_true, weight=1):
    reg_loss = calculate_loss_regressor(psi_out_mod, tf.reduce_mean(psi_out_true, axis=[1], keepdims=True), 
                    params, hyper_params, weight=weight)
    reg_loss = tf.cast(reg_loss, tf.float32)
-    return reg_loss
+    regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+    reg_total_loss = tf.add_n([reg_loss] + regularization, name='total_loss')
+    reg_total_loss = tf.cast(reg_total_loss, tf.float32)
+    return reg_total_loss

def fully_connected(n_net, layer_params, batch_size, wd=0, name=None, reuse=None):
    input = tf.cast(tf.reshape(n_net.model_output,[batch_size, -1]), tf.float32)
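The added return path follows the standard TF1 idiom of folding graph-registered weight-decay terms into the scalar loss. A minimal, self-contained sketch of that idiom (the function name and the empty-collection fallback are illustrative, not from this repo):

import tensorflow as tf  # TF 1.x graph mode

def loss_with_weight_decay(data_loss, name='total_loss'):
    # Layers built with a regularizer register their penalty terms
    # in the REGULARIZATION_LOSSES collection.
    reg_terms = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if not reg_terms:
        # tf.add_n rejects an empty list, so fall back to the bare loss.
        return tf.cast(data_loss, tf.float32)
    return tf.cast(tf.add_n([data_loss] + reg_terms, name=name), tf.float32)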
+31 −42
@@ -548,9 +548,8 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
            # stop_op = tf.stop_gradient(n_net.model_output['encoder'])

            # calculate the total loss
-            # psi_out_true = tf.placeholder(tf.float32, shape=images.shape.as_list(), name="psi_out_true")
             psi_out_true = images
-            constr_loss = losses.get_YNet_constraint(n_net, hyper_params, params, psi_out_true, weight=10)
+            constr_loss = losses.get_YNet_constraint(n_net, hyper_params, params, images, weight=10)
            total_loss, _, indv_losses = losses.calc_loss(n_net, scope, hyper_params, params, labels, step=global_step, images=images, summary=summary)

            #get summaries, except for the one produced by string_input_producer
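With the placeholder line dropped, psi_out_true is just an alias for the images tensor; the feed_dict={psi_out_true: current_batch} calls in the inner loop further down still work because a TF1 Session can feed values into most graph tensors, not only placeholders. A toy sketch of that behavior (all names illustrative):

import numpy as np
import tensorflow as tf  # TF 1.x

x = tf.ones([2, 2])   # an ordinary tensor, not a placeholder
y = x * 3.0
with tf.Session() as sess:
    sess.run(y)                                       # uses x as built (all ones)
    sess.run(y, feed_dict={x: np.full((2, 2), 2.0)})  # overrides x for this call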
@@ -560,48 +559,50 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
        #######################################
        # Apply Gradients and setup train op #
        #######################################

-        # get learning policy
-        def learning_policy_func(step):
-            return lr_policies.decay_warmup(params, hyper_params, step)
-            ## TODO: implement other policies in lr_policies

+        # optimizer for unsupervised step
+        var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' in str(itm.name)] 
+        reg_hyper = deepcopy(hyper_params)
+        reg_hyper['initial_learning_rate'] = 1e-1
+        def learning_policy_func_reg(step):
+            return lr_policies.decay_warmup(params, reg_hyper, step)
        iter_size = params.get('accumulate_step', 0)
        skip_update_cond = tf.cast(tf.floormod(global_step, tf.constant(iter_size, dtype=tf.int32)), tf.bool)

        if params['IMAGE_FP16']:
            opt_type='mixed'
        else:
            opt_type=tf.float32
        # setup optimizer

+        reg_opt, learning_rate = optimizers.optimize_loss(constr_loss, 'Momentum', 
+                                {'momentum': 0.9}, learning_policy_func_reg, var_list=var_list, run_params=params, hyper_params=reg_hyper, iter_size=iter_size, dtype=opt_type, 
+                                loss_scaling=1.0, 
+                                skip_update_cond=skip_update_cond,
+                                on_horovod=True, model_scopes=None)

        # optimizer for supervised step 
        def learning_policy_func(step):
            return lr_policies.decay_warmup(params, hyper_params, step)
            ## TODO: implement other policies in lr_policies

        opt_dict = hyper_params['optimization']['params']  
        train_opt, learning_rate = optimizers.optimize_loss(total_loss, hyper_params['optimization']['name'], 
                                opt_dict, learning_policy_func, run_params=params, hyper_params=hyper_params, iter_size=iter_size, dtype=opt_type, 
                                loss_scaling=hyper_params.get('loss_scaling',1.0), 
                                skip_update_cond=skip_update_cond,
                                on_horovod=True, model_scopes=n_net.scopes)  
-        # optimizer for regularization step
-        # var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' not in str(itm.name)] 
-        var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' not in str(itm.name)] 
-        # var_list = None
-        # print_rank(var_list)
-        opt = tf.train.MomentumOptimizer(1e-5, 0.9)
-        reg_opt = opt.minimize(constr_loss, var_list=var_list)

-    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

-    # with tf.control_dependencies([tf.group(*[stop_op, reg_opt, update_ops])]):
+    # Gather unsupervised training ops 
+    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+    ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step)
+    increment_op = tf.assign_add(global_step, 1)
     with tf.control_dependencies([tf.group(*[reg_opt, update_ops])]):
-        reg_train = tf.no_op(name='reg_train')
+        reg_op = ema.apply(var_list=var_list)

-    # Gather all training related ops into a single one.
+    # Gather supervised training related ops into a single one.
    increment_op = tf.assign_add(global_step, 1)
    ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step)
    all_ops = tf.group(*([train_opt] + update_ops + IO_ops + [increment_op]))

    with tf.control_dependencies([all_ops]):
            train_op = ema.apply(tf.trainable_variables()) 
            # train_op = tf.no_op(name='train')
    
    ########################
    # Setting up Summaries #
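Taken together, the graph now holds two optimizers over two variable partitions, each followed by an exponential-moving-average update, so the run loop can alternate a supervised step with an unsupervised one. A stripped-down sketch of that pattern in plain TF1 (the losses, learning rates, and the 'CVAE' name split are stand-ins for the real configuration):

import tensorflow as tf  # TF 1.x

def build_alternating_ops(total_loss, constr_loss, global_step):
    trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    cvae_vars = [v for v in trainable if 'CVAE' in v.name]
    other_vars = [v for v in trainable if 'CVAE' not in v.name]

    ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step)
    increment_op = tf.assign_add(global_step, 1)

    # Unsupervised step: update only the CVAE partition, then refresh its EMA.
    reg_minimize = tf.train.MomentumOptimizer(1e-1, 0.9).minimize(
        constr_loss, var_list=cvae_vars)
    with tf.control_dependencies([reg_minimize]):
        reg_op = ema.apply(cvae_vars)

    # Supervised step: update the remaining variables, advance the global
    # step, then refresh their EMA.
    sup_minimize = tf.train.MomentumOptimizer(1e-3, 0.9).minimize(
        total_loss, var_list=other_vars)
    with tf.control_dependencies([tf.group(sup_minimize, increment_op)]):
        train_op = ema.apply(other_vars)
    return reg_op, train_op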
@@ -668,7 +669,7 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
    logFreq   = params[ 'log_frequency' ]
    traceStep = params[ 'trace_step' ]
    maxTime = params.get('max_time', 1e12)
-    inner_loop = params.get('inner_iter', 1e12)
+    inner_loop = hyper_params.get('inner_iter', 1e12)
    
    val_results = []
    loss_results = []
@@ -732,27 +733,15 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
            return val_results, loss_results
        if np.isnan(loss_value):
            break
-        # if doLog:
-            # constr_val = sess.run(constr_loss, feed_dict={psi_out_true:current_batch})
-            # print_rank('\t\tstep={}, current constr_loss={:2.3e}'.format(train_elf.last_step, constr_val))
-        # current_batch_list = []
        if inner_loop < 100:
            batch_buffer.append(current_batch)
-        # print_rank(len(batch_buffer))
        if bool(train_elf.last_step % inner_loop == 0 and train_elf.last_step >= 10):
            for itr, current_batch in enumerate(batch_buffer):
-                # noise = np.random.random(images.shape.as_list()[1:]) 
-                # noise = noise.astype(np.float32)
-                # mix = 0.25
-                # current_batch = (1 - mix) * current_batch + mix * noise[np.newaxis] 
-                _, constr_val = sess.run([reg_train, constr_loss], feed_dict={psi_out_true:current_batch})
-                # if doLog:
+                _, constr_val = sess.run([reg_op, constr_loss], feed_dict={psi_out_true:current_batch})
+                if doLog:
+                    print_rank('\t\tstep={}, reg iter={}, constr_loss={:2.3e}'.format(train_elf.last_step, itr, constr_val))
-                    # print('\t\trank={}, step={}, reg iter={}, constr_loss={:2.3e}'.format(hvd.rank(), train_elf.last_step, itr, constr_val))
            del batch_buffer
            batch_buffer = []
-            # for i in range(len(IO_ops)):
-            #     sess.run(IO_ops[:i + 1])
            
    val_results.append((train_elf.last_step,val))
    tf.reset_default_graph()
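For reference, the schedule the last hunk converges on: when the inner loop is enabled, recent batches are buffered and, every inner_loop supervised steps, replayed through the unsupervised op. A compact sketch of that loop (sess, train_op, reg_op, constr_loss, psi_out_true, and fetch_batch stand in for the real training harness):

batch_buffer = []
for step in range(1, max_steps + 1):
    current_batch = fetch_batch()            # stands in for the IO_ops pipeline
    sess.run(train_op)                       # supervised update
    if inner_loop < 100:
        batch_buffer.append(current_batch)   # buffer only when replay is active
    if step % inner_loop == 0 and step >= 10:
        for itr, replay_batch in enumerate(batch_buffer):
            _, constr_val = sess.run([reg_op, constr_loss],
                                     feed_dict={psi_out_true: replay_batch})
        batch_buffer = []                    # start a fresh buffer for the next cycle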