Commit a3098084 authored by Laanait, Nouamane

modifying inner training loop of YNet

parent 10a69fbd
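
For orientation, this commit adds an inner constraint-minimization loop to YNet training: batches are buffered while the regular train_op runs, and every 10 steps a separate MomentumOptimizer (reg_train, driven by constr_loss from losses.get_YNet_constraint) takes one extra step per buffered batch on the thin-object forward-model constraint. Below is a minimal, self-contained TF 1.x sketch of that alternating pattern only; the losses and names (main_loss, constr_loss, psi_out_true) are stand-ins for illustration, not the repository's actual graph.

import numpy as np
import tensorflow as tf

# Toy variable and placeholder standing in for the network weights and the
# measured data fed to the constraint loss (psi_out_true / images in the script).
w = tf.get_variable('w', shape=[4], initializer=tf.zeros_initializer())
psi_out_true = tf.placeholder(tf.float32, shape=[None, 4], name='psi_out_true')

main_loss = tf.reduce_mean(tf.square(w - 1.0))             # stand-in for total_loss
constr_loss = tf.reduce_mean(tf.square(w - psi_out_true))  # stand-in for get_YNet_constraint(...)

train_op = tf.train.AdamOptimizer(1e-3).minimize(main_loss)
# Separate optimizer for the constraint step, mirroring MomentumOptimizer(1e-5, 0.9) in the diff.
reg_train = tf.train.MomentumOptimizer(1e-5, 0.9).minimize(constr_loss)

reg_every = 10
batch_buffer = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1, 31):
        batch = np.random.rand(2, 4).astype(np.float32)     # stands in for the fetched images
        sess.run(train_op)                                   # regular training step
        batch_buffer.append(batch)
        if step % reg_every == 0:
            # Inner loop: one constraint-minimization step per buffered batch.
            for itr, b in enumerate(batch_buffer):
                _, c = sess.run([reg_train, constr_loss], feed_dict={psi_out_true: b})
                print('step={}, reg iter={}, constr_loss={:2.3e}'.format(step, itr, c))
            batch_buffer = []

In the actual script, the constraint loss is built from the network outputs via losses.get_YNet_constraint and thin_object, and the buffered batches are the images fetched alongside train_op.
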
+18 −8
@@ -83,13 +83,13 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
        pot_labels, probe_labels_re, probe_labels_im = [tf.expand_dims(itm, axis=1) for itm in tf.unstack(labels, axis=1)]
        #weight= np.prod(pot_labels.shape.as_list()[-2:])
        weight=None
-        inv_str = hyper_params.get('inv_strength', 0.1)
-        reg_str = hyper_params.get('reg_strength', 0.1)
+        inv_str = hyper_params.get('inv_strength', 1)
+        reg_str = hyper_params.get('reg_strength', 0.01)
        dec_str = hyper_params.get('dec_strength', 1) 
        inverter_loss = inv_str * calculate_loss_regressor(pot, pot_labels, params, hyper_params, weight=weight)
        decoder_loss_im = dec_str * calculate_loss_regressor(probe_im, probe_labels_im, params, hyper_params, weight=weight)
        decoder_loss_re = dec_str * calculate_loss_regressor(probe_re, probe_labels_re, params, hyper_params, weight=weight)
-        psi_out_mod = thin_object(probe_re, probe_im, pot)
+        psi_out_mod = thin_object(probe_re, probe_im, pot, summarize=False)
        reg_loss = reg_str * calculate_loss_regressor(psi_out_mod, tf.reduce_mean(images, axis=[1], keepdims=True), 
                    params, hyper_params, weight=weight)
        tf.summary.scalar('reg_loss ', reg_loss)
@@ -111,7 +111,6 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
    #Assemble all of the losses.
    losses = tf.get_collection(tf.GraphKeys.LOSSES)
    if hyper_params['network_type'] == 'YNet':
-        reg_str = hyper_params.get('reg_strength', 0.1)
        losses = [inverter_loss , decoder_loss_re, decoder_loss_im, reg_loss]
        # losses, prefac = ynet_adjusted_losses(losses, step)
    regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
@@ -125,6 +124,16 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
        return total_loss, loss_averages_op, losses 
    return total_loss, loss_averages_op

+def get_YNet_constraint(n_net, hyper_params, params, psi_out_true, weight=1):
+    probe_im = tf.cast(n_net.model_output['decoder_IM'], tf.float32)
+    probe_re = tf.cast(n_net.model_output['decoder_RE'], tf.float32)
+    pot = tf.cast(n_net.model_output['inverter'], tf.float32)
+    psi_out_mod = thin_object(probe_re, probe_im, pot)
+    reg_loss = calculate_loss_regressor(psi_out_mod, tf.reduce_mean(psi_out_true, axis=[1], keepdims=True), 
+                    params, hyper_params, weight=weight)
+    reg_loss = tf.cast(reg_loss, tf.float32)
+    return reg_loss
+
def fully_connected(n_net, layer_params, batch_size, wd=0, name=None, reuse=None):
    input = tf.cast(tf.reshape(n_net.model_output,[batch_size, -1]), tf.float32)
    dim_input = input.shape[1].value
@@ -198,7 +207,7 @@ def calculate_loss_regressor(net_output, labels, params, hyper_params, weight=No
                                    reduction=tf.losses.Reduction.MEAN)
    if loss_params['type'] == 'MSE':
        cost = tf.losses.mean_squared_error(labels, weights=weight, predictions=net_output,
-                                            reduction=tf.losses.Reduction.SUM)
+                                            reduction=tf.losses.Reduction.MEAN)
    if loss_params['type'] == 'ABS_DIFF':
        cost = tf.losses.absolute_difference(labels, weights=weight, predictions=net_output,
                                            reduction=tf.losses.Reduction.MEAN)
@@ -251,7 +260,7 @@ def fftshift(tensor, tens_format='NCHW'):
    shift_tensor = manip_ops.roll(tensor, shift, dims)
    return shift_tensor

-def thin_object(psi_k_re, psi_k_im, potential):
+def thin_object(psi_k_re, psi_k_im, potential, summarize=True):
    # mask = np.zeros(psi_k_re.shape.as_list(), dtype=np.float32)
    # ratio = 0
    # if ratio == 0:
@@ -270,6 +279,7 @@ def thin_object(psi_k_re, psi_k_im, potential):
    psi_out = tf.fft2d(psi_x_stack * pot_frac / np.prod(psi_x.shape.as_list()))
    psi_out_mod = tf.cast(tf.abs(psi_out), tf.float32) ** 2
    psi_out_mod = tf.reduce_mean(psi_out_mod, axis=1, keep_dims=True)
+    if summarize:
+        tf.summary.image('Psi_k_out', tf.transpose(tf.abs(psi_out_mod)**0.25, perm=[0,2,3,1]), max_outputs=1)
+        tf.summary.image('Psi_x_in', tf.transpose(tf.abs(psi_x)**0.25, perm=[0,2,3,1]), max_outputs=1)
    return psi_out_mod 
+3 −1
@@ -708,8 +708,10 @@ class ConvNet:
                return tf.nn.tanh(input, name=name)
            elif params['activation'] == 'leaky_relu':
                return tf.nn.leaky_relu(input, name=name)
-            else:
+            elif params['activation'] == 'relu':
                return tf.nn.relu(input, name=name)
+            elif params['activation'] == 'none':
+                return input
        else:
            return input

+1 −1
@@ -312,7 +312,7 @@ def optimize_loss(loss,

    # Compute gradients.
    grads_and_vars = opt.compute_gradients(
-        loss, colocate_gradients_with_ops=True, var_list=var_list
+        loss, var_list=var_list
    )

    if dtype == 'mixed' or dtype == tf.float16:
+52 −9
@@ -13,6 +13,7 @@ import numpy as np
import math
from itertools import chain
from multiprocessing import cpu_count
+from copy import deepcopy

#TF
import tensorflow as tf
@@ -543,7 +544,13 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
            # Build it and propagate images through it.
            n_net.build_model()

+            # # Stop gradients 
+            # stop_op = tf.stop_gradient(n_net.model_output['encoder'])
+
            # calculate the total loss
+            # psi_out_true = tf.placeholder(tf.float32, shape=images.shape.as_list(), name="psi_out_true")
+            psi_out_true = images
+            constr_loss = losses.get_YNet_constraint(n_net, hyper_params, params, psi_out_true, weight=10)
            total_loss, _, indv_losses = losses.calc_loss(n_net, scope, hyper_params, params, labels, step=global_step, images=images, summary=summary)

            #get summaries, except for the one produced by string_input_producer
@@ -573,9 +580,21 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
                                loss_scaling=hyper_params.get('loss_scaling',1.0), 
                                skip_update_cond=skip_update_cond,
                                on_horovod=True, model_scopes=n_net.scopes)  
+        # optimizer for regularization step
+        # var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' not in str(itm.name)] 
+        var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' not in str(itm.name)] 
+        # var_list = None
+        # print_rank(var_list)
+        opt = tf.train.MomentumOptimizer(1e-5, 0.9)
+        reg_opt = opt.minimize(constr_loss, var_list=var_list)
    
    # Gather all training related ops into a single one.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

+    # with tf.control_dependencies([tf.group(*[stop_op, reg_opt, update_ops])]):
+    with tf.control_dependencies([tf.group(*[reg_opt, update_ops])]):
+        reg_train = tf.no_op(name='reg_train')
+
+    # Gather all training related ops into a single one.
    increment_op = tf.assign_add(global_step, 1)
    ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step)
    all_ops = tf.group(*([train_opt] + update_ops + IO_ops + [increment_op]))
@@ -654,7 +673,10 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
    loss_results = []
    loss_value = 1e10
    val = 1e10
+    current_batch = np.zeros(images.shape.as_list(), dtype=np.float32)
+    batch_buffer = []
    while train_elf.last_step < maxSteps :
+        # batch_buffer.append(images.eval(session=sess))
        train_elf.before_run()
        doLog   = bool(train_elf.last_step % logFreq  == 0)
        doSave  = bool(train_elf.last_step % saveStep == 0)
@@ -663,17 +685,17 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
        doValidate = bool(train_elf.last_step % validateStep == 0)
        doFinish = bool(train_elf.start_time - params['start_time'] > maxTime)
        if train_elf.last_step == 1 and params['debug']:
-            summary = sess.run([train_op,  summary_merged])[-1]
+            _, summary, current_batch = sess.run([train_op,  summary_merged, images])
            train_elf.write_summaries( summary )
        elif not doLog and not doSave and not doTrace and not doSumm:
-            sess.run(train_op)
+            _, current_batch = sess.run([train_op, images])
        elif doLog and not doSave  and not doSumm:
-            _, lr, loss_value, aux_losses = sess.run( [ train_op, learning_rate, total_loss, indv_losses])
+            _, lr, loss_value, aux_losses, current_batch = sess.run( [ train_op, learning_rate, total_loss, indv_losses, images])
            loss_results.append((train_elf.last_step, loss_value))
            train_elf.log_stats( loss_value, aux_losses, lr)
        elif doLog and doSumm and doSave :
-            _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses,
-                                                             learning_rate ])
+            _, summary, loss_value, aux_losses, lr, current_batch = sess.run( [ train_op, summary_merged, total_loss, indv_losses,
+                                                             learning_rate, images ])
            loss_results.append((train_elf.last_step, loss_value))
            train_elf.log_stats( loss_value, aux_losses, lr )
            train_elf.write_summaries( summary )
@@ -681,12 +703,12 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
                saver.save(sess, checkpoint_file, global_step=train_elf.last_step)
                print_rank('Saved Checkpoint.')
        elif doLog and doSumm :
-            _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate ])
+            _, summary, loss_value, aux_losses, lr, current_batch = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate, images ])
            loss_results.append((train_elf.last_step, loss_value))
            train_elf.log_stats( loss_value, aux_losses, lr )
            train_elf.write_summaries( summary )
        elif doSumm:
-            summary = sess.run([train_op,  summary_merged])[-1]
+            _, summary, current_batch  = sess.run([train_op,  summary_merged, images])
            train_elf.write_summaries( summary )
        elif doSave :
            if hvd.rank( ) == 0 :
@@ -709,6 +731,27 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
            return val_results, loss_results
        if np.isnan(loss_value):
            break
+        # if doLog:
+            # constr_val = sess.run(constr_loss, feed_dict={psi_out_true:current_batch})
+            # print_rank('\t\tstep={}, current constr_loss={:2.3e}'.format(train_elf.last_step, constr_val))
+        # current_batch_list = []
+        batch_buffer.append(current_batch)
+        # print_rank(len(batch_buffer))
+        if bool(train_elf.last_step % 10 == 0 and train_elf.last_step >= 10):
+            for itr, current_batch in enumerate(batch_buffer):
+                # noise = np.random.random(images.shape.as_list()[1:]) 
+                # noise = noise.astype(np.float32)
+                # mix = 0.25
+                # current_batch = (1 - mix) * current_batch + mix * noise[np.newaxis] 
+                _, constr_val = sess.run([reg_train, constr_loss], feed_dict={psi_out_true:current_batch})
+                # if doLog:
+                print_rank('\t\tstep={}, reg iter={}, constr_loss={:2.3e}'.format(train_elf.last_step, itr, constr_val))
+                    # print('\t\trank={}, step={}, reg iter={}, constr_loss={:2.3e}'.format(hvd.rank(), train_elf.last_step, itr, constr_val))
+            del batch_buffer
+            batch_buffer = []
+            # for i in range(len(IO_ops)):
+            #     sess.run(IO_ops[:i + 1])
+            
    val_results.append((train_elf.last_step,val))
    tf.reset_default_graph()
    tf.keras.backend.clear_session()