Commit a3098084 authored by Laanait, Nouamane

modifying inner training loop of YNet

parent 10a69fbd
Pipeline #81959 failed in 2 minutes and 24 seconds
@@ -83,13 +83,13 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
     pot_labels, probe_labels_re, probe_labels_im = [tf.expand_dims(itm, axis=1) for itm in tf.unstack(labels, axis=1)]
     #weight= np.prod(pot_labels.shape.as_list()[-2:])
     weight=None
-    inv_str = hyper_params.get('inv_strength', 0.1)
-    reg_str = hyper_params.get('reg_strength', 0.1)
+    inv_str = hyper_params.get('inv_strength', 1)
+    reg_str = hyper_params.get('reg_strength', 0.01)
     dec_str = hyper_params.get('dec_strength', 1)
     inverter_loss = inv_str * calculate_loss_regressor(pot, pot_labels, params, hyper_params, weight=weight)
     decoder_loss_im = dec_str * calculate_loss_regressor(probe_im, probe_labels_im, params, hyper_params, weight=weight)
     decoder_loss_re = dec_str * calculate_loss_regressor(probe_re, probe_labels_re, params, hyper_params, weight=weight)
-    psi_out_mod = thin_object(probe_re, probe_im, pot)
+    psi_out_mod = thin_object(probe_re, probe_im, pot, summarize=False)
     reg_loss = reg_str * calculate_loss_regressor(psi_out_mod, tf.reduce_mean(images, axis=[1], keepdims=True),
                                                   params, hyper_params, weight=weight)
     tf.summary.scalar('reg_loss ', reg_loss)
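
Note on the re-weighting above: the inverter term's default weight moves from 0.1 to 1 and the physics-consistency (reg) term from 0.1 to 0.01, while the decoder terms stay at 1. A minimal sketch of how the weighted terms are consumed downstream (assuming the YNet branch of calc_loss below; tf.add_n as the combiner is an assumption, not shown in this hunk):

    # Hypothetical assembly of the four weighted terms (combiner assumed):
    losses = [inverter_loss, decoder_loss_re, decoder_loss_im, reg_loss]
    total_loss = tf.add_n(losses)  # each term already carries its *_str factor
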
@@ -111,7 +111,6 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
     #Assemble all of the losses.
     losses = tf.get_collection(tf.GraphKeys.LOSSES)
     if hyper_params['network_type'] == 'YNet':
-        reg_str = hyper_params.get('reg_strength', 0.1)
         losses = [inverter_loss , decoder_loss_re, decoder_loss_im, reg_loss]
         # losses, prefac = ynet_adjusted_losses(losses, step)
     regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
@@ -125,6 +124,16 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
         return total_loss, loss_averages_op, losses
     return total_loss, loss_averages_op
 
+def get_YNet_constraint(n_net, hyper_params, params, psi_out_true, weight=1):
+    probe_im = tf.cast(n_net.model_output['decoder_IM'], tf.float32)
+    probe_re = tf.cast(n_net.model_output['decoder_RE'], tf.float32)
+    pot = tf.cast(n_net.model_output['inverter'], tf.float32)
+    psi_out_mod = thin_object(probe_re, probe_im, pot)
+    reg_loss = calculate_loss_regressor(psi_out_mod, tf.reduce_mean(psi_out_true, axis=[1], keepdims=True),
+                                        params, hyper_params, weight=weight)
+    reg_loss = tf.cast(reg_loss, tf.float32)
+    return reg_loss
+
 def fully_connected(n_net, layer_params, batch_size, wd=0, name=None, reuse=None):
     input = tf.cast(tf.reshape(n_net.model_output,[batch_size, -1]), tf.float32)
     dim_input = input.shape[1].value
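
In words, the new get_YNet_constraint re-runs the thin-object forward model on the network's current probe (decoder_RE/decoder_IM) and potential (inverter) estimates and scores the result against the mean measured pattern, roughly reg_loss = L(|F{psi · t(V)}|^2, mean_k(data)). It needs no ground-truth labels, only the measured data passed as psi_out_true, which is what lets it drive the separate regularization step added to train_YNet below.
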
@@ -198,7 +207,7 @@ def calculate_loss_regressor(net_output, labels, params, hyper_params, weight=None
                                              reduction=tf.losses.Reduction.MEAN)
     if loss_params['type'] == 'MSE':
         cost = tf.losses.mean_squared_error(labels, weights=weight, predictions=net_output,
-                                            reduction=tf.losses.Reduction.SUM)
+                                            reduction=tf.losses.Reduction.MEAN)
     if loss_params['type'] == 'ABS_DIFF':
         cost = tf.losses.absolute_difference(labels, weights=weight, predictions=net_output,
                                              reduction=tf.losses.Reduction.MEAN)
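
The SUM-to-MEAN switch changes the loss magnitude by the element count, and with it the effective gradient scale; a quick numpy illustration (not from the source):

    import numpy as np
    err_sq = (0.5 * np.ones((4, 1, 64, 64))) ** 2  # squared error of 0.25 everywhere
    print(err_sq.sum())   # 4096.0 -> what Reduction.SUM would report
    print(err_sq.mean())  # 0.25   -> what Reduction.MEAN reports
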
@@ -251,7 +260,7 @@ def fftshift(tensor, tens_format='NCHW'):
     shift_tensor = manip_ops.roll(tensor, shift, dims)
     return shift_tensor
 
-def thin_object(psi_k_re, psi_k_im, potential):
+def thin_object(psi_k_re, psi_k_im, potential, summarize=True):
     # mask = np.zeros(psi_k_re.shape.as_list(), dtype=np.float32)
     # ratio = 0
     # if ratio == 0:
@@ -270,6 +279,7 @@ def thin_object(psi_k_re, psi_k_im, potential):
     psi_out = tf.fft2d(psi_x_stack * pot_frac / np.prod(psi_x.shape.as_list()))
     psi_out_mod = tf.cast(tf.abs(psi_out), tf.float32) ** 2
     psi_out_mod = tf.reduce_mean(psi_out_mod, axis=1, keep_dims=True)
-    tf.summary.image('Psi_k_out', tf.transpose(tf.abs(psi_out_mod)**0.25, perm=[0,2,3,1]), max_outputs=1)
-    tf.summary.image('Psi_x_in', tf.transpose(tf.abs(psi_x)**0.25, perm=[0,2,3,1]), max_outputs=1)
+    if summarize:
+        tf.summary.image('Psi_k_out', tf.transpose(tf.abs(psi_out_mod)**0.25, perm=[0,2,3,1]), max_outputs=1)
+        tf.summary.image('Psi_x_in', tf.transpose(tf.abs(psi_x)**0.25, perm=[0,2,3,1]), max_outputs=1)
     return psi_out_mod
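
For reference, the forward model thin_object implements, reduced to single-channel numpy (a sketch inferred from the code above; the exact pot_frac phase term is an assumption):

    import numpy as np
    rng = np.random.default_rng(0)
    psi_x = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))  # probe, real space
    V = rng.standard_normal((64, 64))              # projected potential
    psi_exit = psi_x * np.exp(1j * V)              # thin-object (multiplicative) approximation
    psi_k = np.fft.fft2(psi_exit) / psi_exit.size  # propagate to detector plane
    intensity = np.abs(psi_k) ** 2                 # modulus squared, as in psi_out_mod
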
@@ -708,8 +708,10 @@ class ConvNet:
             return tf.nn.tanh(input, name=name)
         elif params['activation'] == 'leaky_relu':
             return tf.nn.leaky_relu(input, name=name)
-        else:
+        elif params['activation'] == 'relu':
             return tf.nn.relu(input, name=name)
+        elif params['activation'] == 'none':
+            return input
+        else:
+            return input
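
Behavioral note: ReLU was previously the else fallback; it now has to be requested explicitly, and both 'none' and any unrecognized string return the input unchanged, so a misspelled activation name silently becomes the identity.
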
...
@@ -312,7 +312,7 @@ def optimize_loss(loss,
     # Compute gradients.
     grads_and_vars = opt.compute_gradients(
-        loss, colocate_gradients_with_ops=True, var_list=var_list
+        loss, var_list=var_list
     )
     if dtype == 'mixed' or dtype == tf.float16:
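
Removing colocate_gradients_with_ops=True reverts compute_gradients to its TF 1.x default (False), so gradient ops are no longer pinned to the device of their forward ops and the placer is free to schedule them elsewhere.
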
...
@@ -13,6 +13,7 @@ import numpy as np
 import math
 from itertools import chain
 from multiprocessing import cpu_count
+from copy import deepcopy
 #TF
 import tensorflow as tf
@@ -543,7 +544,13 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
         # Build it and propagate images through it.
         n_net.build_model()
+        # # Stop gradients
+        # stop_op = tf.stop_gradient(n_net.model_output['encoder'])
         # calculate the total loss
+        # psi_out_true = tf.placeholder(tf.float32, shape=images.shape.as_list(), name="psi_out_true")
+        psi_out_true = images
+        constr_loss = losses.get_YNet_constraint(n_net, hyper_params, params, psi_out_true, weight=10)
         total_loss, _, indv_losses = losses.calc_loss(n_net, scope, hyper_params, params, labels, step=global_step, images=images, summary=summary)
         #get summaries, except for the one produced by string_input_producer
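
psi_out_true is now the live images tensor rather than a placeholder, yet the loop below still feeds it via feed_dict. That works because TF 1.x lets feed_dict override any feedable tensor, not only placeholders; a self-contained illustration:

    import tensorflow as tf
    x = tf.constant([1.0, 2.0])
    y = 2.0 * x
    with tf.Session() as sess:
        print(sess.run(y))                             # [2. 4.] -- graph value of x
        print(sess.run(y, feed_dict={x: [5.0, 6.0]}))  # [10. 12.] -- x overridden
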
@@ -573,9 +580,21 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
                                    loss_scaling=hyper_params.get('loss_scaling',1.0),
                                    skip_update_cond=skip_update_cond,
                                    on_horovod=True, model_scopes=n_net.scopes)
+        # optimizer for regularization step
+        # var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' not in str(itm.name)]
+        var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' not in str(itm.name)]
+        # var_list = None
+        # print_rank(var_list)
+        opt = tf.train.MomentumOptimizer(1e-5, 0.9)
+        reg_opt = opt.minimize(constr_loss, var_list=var_list)
+        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+        # with tf.control_dependencies([tf.group(*[stop_op, reg_opt, update_ops])]):
+        with tf.control_dependencies([tf.group(*[reg_opt, update_ops])]):
+            reg_train = tf.no_op(name='reg_train')
+
         # Gather all training related ops into a single one.
         update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
         increment_op = tf.assign_add(global_step, 1)
         ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step)
         all_ops = tf.group(*([train_opt] + update_ops + IO_ops + [increment_op]))
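
Two optimizers now coexist: the Horovod-wrapped train_opt for the joint loss, and a plain MomentumOptimizer that minimizes only the constraint over the non-CVAE variables. The no_op-under-control_dependencies idiom gives a single fetchable handle for the regularization step; minimally (an illustrative stand-alone sketch, not the full script):

    import tensorflow as tf

    # Toy variable and "constraint" so the idiom runs on its own.
    v = tf.Variable(1.0)
    constr = tf.square(v)
    reg_opt = tf.train.MomentumOptimizer(1e-5, 0.9).minimize(constr)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # e.g. batch-norm moving stats
    with tf.control_dependencies([tf.group(reg_opt, *update_ops)]):
        reg_train = tf.no_op(name='reg_train')  # fetching this runs reg_opt + update_ops first
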
@@ -583,7 +602,7 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
         with tf.control_dependencies([all_ops]):
             train_op = ema.apply(tf.trainable_variables())
             # train_op = tf.no_op(name='train')
         ########################
         # Setting up Summaries #
         ########################
@@ -654,7 +673,10 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
         loss_results = []
         loss_value = 1e10
         val = 1e10
+        current_batch = np.zeros(images.shape.as_list(), dtype=np.float32)
+        batch_buffer = []
         while train_elf.last_step < maxSteps :
+            # batch_buffer.append(images.eval(session=sess))
             train_elf.before_run()
             doLog = bool(train_elf.last_step % logFreq == 0)
             doSave = bool(train_elf.last_step % saveStep == 0)
@@ -663,17 +685,17 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
             doValidate = bool(train_elf.last_step % validateStep == 0)
             doFinish = bool(train_elf.start_time - params['start_time'] > maxTime)
             if train_elf.last_step == 1 and params['debug']:
-                summary = sess.run([train_op, summary_merged])[-1]
+                _, summary, current_batch = sess.run([train_op, summary_merged, images])
                 train_elf.write_summaries( summary )
             elif not doLog and not doSave and not doTrace and not doSumm:
-                sess.run(train_op)
+                _, current_batch = sess.run([train_op, images])
             elif doLog and not doSave and not doSumm:
-                _, lr, loss_value, aux_losses = sess.run( [ train_op, learning_rate, total_loss, indv_losses])
+                _, lr, loss_value, aux_losses, current_batch = sess.run( [ train_op, learning_rate, total_loss, indv_losses, images])
                 loss_results.append((train_elf.last_step, loss_value))
                 train_elf.log_stats( loss_value, aux_losses, lr)
             elif doLog and doSumm and doSave :
-                _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses,
-                                                                     learning_rate ])
+                _, summary, loss_value, aux_losses, lr, current_batch = sess.run( [ train_op, summary_merged, total_loss, indv_losses,
+                                                                                    learning_rate, images ])
                 loss_results.append((train_elf.last_step, loss_value))
                 train_elf.log_stats( loss_value, aux_losses, lr )
                 train_elf.write_summaries( summary )
@@ -681,12 +703,12 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
                     saver.save(sess, checkpoint_file, global_step=train_elf.last_step)
                     print_rank('Saved Checkpoint.')
             elif doLog and doSumm :
-                _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate ])
+                _, summary, loss_value, aux_losses, lr, current_batch = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate, images ])
                 loss_results.append((train_elf.last_step, loss_value))
                 train_elf.log_stats( loss_value, aux_losses, lr )
                 train_elf.write_summaries( summary )
             elif doSumm:
-                summary = sess.run([train_op, summary_merged])[-1]
+                _, summary, current_batch = sess.run([train_op, summary_merged, images])
                 train_elf.write_summaries( summary )
             elif doSave :
                 if hvd.rank( ) == 0 :
@@ -709,6 +731,27 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
                     return val_results, loss_results
             if np.isnan(loss_value):
                 break
+            # if doLog:
+            #     constr_val = sess.run(constr_loss, feed_dict={psi_out_true:current_batch})
+            #     print_rank('\t\tstep={}, current constr_loss={:2.3e}'.format(train_elf.last_step, constr_val))
+            # current_batch_list = []
+            batch_buffer.append(current_batch)
+            # print_rank(len(batch_buffer))
+            if bool(train_elf.last_step % 10 == 0 and train_elf.last_step >= 10):
+                for itr, current_batch in enumerate(batch_buffer):
+                    # noise = np.random.random(images.shape.as_list()[1:])
+                    # noise = noise.astype(np.float32)
+                    # mix = 0.25
+                    # current_batch = (1 - mix) * current_batch + mix * noise[np.newaxis]
+                    _, constr_val = sess.run([reg_train, constr_loss], feed_dict={psi_out_true:current_batch})
+                    # if doLog:
+                    print_rank('\t\tstep={}, reg iter={}, constr_loss={:2.3e}'.format(train_elf.last_step, itr, constr_val))
+                    # print('\t\trank={}, step={}, reg iter={}, constr_loss={:2.3e}'.format(hvd.rank(), train_elf.last_step, itr, constr_val))
+                del batch_buffer
+                batch_buffer = []
+            # for i in range(len(IO_ops)):
+            #     sess.run(IO_ops[:i + 1])
         val_results.append((train_elf.last_step,val))
     tf.reset_default_graph()
     tf.keras.backend.clear_session()
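
The new inner loop, in outline: every main training step also fetches the input batch and stashes it in batch_buffer; every 10th step replays the buffered batches through the constraint-only optimizer, then clears the buffer. Schematically (hypothetical helper names standing in for the sess.run calls above):

    batch_buffer = []
    for step in range(1, max_steps + 1):
        batch = run_main_step()                    # train_op + fetch of `images`
        batch_buffer.append(batch)
        if step % 10 == 0:
            for replayed in batch_buffer:
                run_reg_step(feed={psi_out_true: replayed})  # reg_train + constr_loss
            batch_buffer = []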