Commit 81ddd819 authored by Laanait, Nouamane

minor changes to validate function for ynet

parent 26de6b23
Pipeline #80459 failed in 2 minutes and 39 seconds
@@ -108,11 +108,9 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
#Assemble all of the losses.
losses = tf.get_collection(tf.GraphKeys.LOSSES)
if hyper_params['network_type'] == 'YNet':
losses = [inverter_loss , decoder_loss_re, decoder_loss_im, reg_loss]
# losses = [inverter_loss , decoder_loss_re, decoder_loss_im]
reg_str = hyper_params.get('reg_strength', 0.1)
losses = [inverter_loss , decoder_loss_re, decoder_loss_im, reg_str * reg_loss]
# losses, prefac = ynet_adjusted_losses(losses, step)
# tf.summary.scalar("prefac_inverter", prefac)
# losses = [inverter_loss]
regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
# Calculate the total loss
total_loss = tf.add_n(losses + regularization, name='total_loss')
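
Note: this hunk scales the YNet regularization term by a new 'reg_strength' hyper-parameter (defaulting to 0.1) instead of adding reg_loss at full weight. A minimal usage sketch, with an illustrative value:

    # Sketch only: enabling the new knob from the hyper-parameter dict.
    hyper_params['reg_strength'] = 0.05   # down-weight reg_loss to 5%
    # calc_loss then builds:
    #   total_loss = inverter + decoder_re + decoder_im + 0.05 * reg_loss
    #                + sum(tf.GraphKeys.REGULARIZATION_LOSSES)
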
@@ -34,7 +34,7 @@ def decay_warmup(params, hyper_params, global_step):
# Decay/ramp the learning rate exponentially based on the number of steps.
def ramp():
lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE, global_step, ramp_steps, LEARNING_RATE_DECAY_FACTOR,
staircase=False)
staircase=True)
lr = INITIAL_LEARNING_RATE ** 2 * tf.pow(lr, tf.constant(-1.))
lr = tf.minimum(lr,WARM_UP_LEARNING_RATE_MAX)
return lr
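
Note: `tf.train.exponential_decay` returns INITIAL_LEARNING_RATE * FACTOR**(step/ramp_steps), so the inversion on the following line turns decay into exponential growth during warm-up; flipping staircase to True quantizes the exponent. In comment form:

    # ramp() in effect computes (capped at WARM_UP_LEARNING_RATE_MAX):
    #   lr = lr0**2 / (lr0 * f**(step/ramp_steps)) = lr0 * f**(-step/ramp_steps)
    # With f < 1 this grows exponentially; staircase=True replaces
    # step/ramp_steps by floor(step/ramp_steps), so growth jumps in stages.
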
@@ -50,10 +50,18 @@ def decay_warmup(params, hyper_params, global_step):
staircase=True)
return lr
warm_up_slope = hyper_params.get('warm_up_slope', 1.)
def constant_ramp():
lr = tf.cast(INITIAL_LEARNING_RATE, tf.float32) * (tf.cast(global_step, tf.float32) * warm_up_slope + 1)
lr = tf.math.minimum(tf.cast(WARM_UP_LEARNING_RATE_MAX, tf.float32), lr)
return lr
if hyper_params['warm_up']:
# LEARNING_RATE = tf.cond(global_step < ramp_up_steps, ramp, lambda: decay(ramp()))
LEARNING_RATE = tf.cond(global_step < ramp_up_steps, linear_ramp, lambda: decay(linear_ramp()))
#LEARNING_RATE = tf.cond(global_step < ramp_up_steps, linear_ramp, lambda: decay(linear_ramp()))
ramp_up_steps = tf.cast(WARM_UP_LEARNING_RATE_MAX/INITIAL_LEARNING_RATE, global_step.dtype)
LEARNING_RATE = tf.cond(global_step < ramp_up_steps, constant_ramp, lambda: decay(constant_ramp()))
else:
LEARNING_RATE = tf.train.exponential_decay(INITIAL_LEARNING_RATE, global_step, decay_steps,
LEARNING_RATE_DECAY_FACTOR, staircase=True)
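
Note: warm-up now uses `constant_ramp`, which grows the learning rate linearly in the step count, and `ramp_up_steps` is derived from the ratio of the warm-up ceiling to the initial rate (with slope 1 the ramp hits the cap exactly at that step). A plain-NumPy-style sketch of the resulting piecewise schedule, with illustrative constants:

    # Sketch of the new schedule; constant values are illustrative assumptions.
    INITIAL_LEARNING_RATE = 1e-4
    WARM_UP_LEARNING_RATE_MAX = 1e-2
    LEARNING_RATE_DECAY_FACTOR = 0.5
    decay_steps, warm_up_slope = 1000, 1.0

    def constant_ramp(step):
        # Linear warm-up: lr0 * (slope * step + 1), capped at the ceiling.
        return min(WARM_UP_LEARNING_RATE_MAX,
                   INITIAL_LEARNING_RATE * (step * warm_up_slope + 1))

    def learning_rate(step):
        ramp_up_steps = WARM_UP_LEARNING_RATE_MAX / INITIAL_LEARNING_RATE  # 100 here
        lr = constant_ramp(step)
        if step < ramp_up_steps:
            return lr
        # Staircase exponential decay takes over once warm-up is done.
        return lr * LEARNING_RATE_DECAY_FACTOR ** (step // decay_steps)
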
@@ -1247,7 +1247,7 @@ class FCDenseNet(ConvNet):
def _freq2space(self, inputs=None):
shape = inputs.shape
weights_dim = 512
weights_dim = 256
num_fc = 2
# if weights_dim < 4096 :
fully_connected = OrderedDict({'type': 'fully_connected','weights': weights_dim,'bias': weights_dim, 'activation': 'relu',
@@ -1266,7 +1266,8 @@ class FCDenseNet(ConvNet):
out = tf.reshape(out, [shape[0], -1])
for i in range(num_fc):
if i > 0:
weights_dim = min(4096, int(shape.as_list()[-2]**2))
#weights_dim = min(4096, int(shape.as_list()[-2]**2))
weights_dim = min(weights_dim, int(shape.as_list()[-2]**2))
fully_connected['weights'] = weights_dim
fully_connected['bias'] = weights_dim
with tf.variable_scope('FC_%d' %i, reuse=self.reuse) as _ :
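
Note: `_freq2space` halves the fully connected width from 512 to 256, and the deeper FC layers are now capped by the configured `weights_dim` rather than a hard-coded 4096, so shrinking `weights_dim` actually shrinks the whole stack. Worked example with an assumed 64x64 feature map:

    # Hypothetical 64x64 spatial footprint:
    old_cap = min(4096, 64 ** 2)   # -> 4096 units per deeper FC layer
    new_cap = min(256,  64 ** 2)   # -> 256 units; follows weights_dim
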
@@ -410,8 +410,8 @@ def train(network_config, hyper_params, params, gpu_id=None):
val = validate(network_config, hyper_params, params, sess, dset, num_batches=50)
val_results.append((train_elf.last_step,val))
if doFinish:
val = validate(network_config, hyper_params, params, sess, dset, num_batches=50)
val_results.append((train_elf.last_step, val))
#val = validate(network_config, hyper_params, params, sess, dset, num_batches=50)
#val_results.append((train_elf.last_step, val))
tf.reset_default_graph()
tf.keras.backend.clear_session()
sess.close()
@@ -518,9 +518,13 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
print('not implemented')
elif hyper_params['network_type'] == 'YNet':
loss_params = hyper_params['loss_function']
model_output = tf.concat([n_net.model_output[subnet] for subnet in ['inverter', 'decoder_RE', 'decoder_IM']], axis=1)
#model_output = tf.concat([n_net.model_output[subnet] for subnet in ['inverter', 'decoder_RE', 'decoder_IM']], axis=1)
model_output = [n_net.model_output[subnet] for subnet in ['inverter', 'decoder_RE', 'decoder_IM']]
labels = [tf.expand_dims(itm, axis=1) for itm in tf.unstack(labels, axis=1)]
if loss_params['type'] == 'MSE_PAIR':
errors = tf.losses.mean_pairwise_squared_error(tf.cast(labels, tf.float32), tf.cast(model_output, tf.float32))
errors = [tf.losses.mean_pairwise_squared_error(tf.cast(label, tf.float32), out)
for label, out in zip(labels, model_output)]
errors = tf.stack(errors)
loss_label= loss_params['type']
elif loss_params['type'] == 'ABS_DIFF':
loss_label= 'ABS_DIFF'
@@ -535,8 +539,8 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
elif num_batches > dset.num_samples:
num_samples = dset.num_samples
errors = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples//params['batch_size'])])
result = errors.mean()
print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, errors.mean()))
result = errors.mean(0)
print_rank('Validation Reconstruction Error %s: '% loss_label, result)
elif hyper_params['network_type'] == 'inverter':
loss_params = hyper_params['loss_function']
if labels.shape.as_list()[1] > 1:
@@ -563,7 +567,7 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
num_samples = dset.num_samples
errors = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples//params['batch_size'])])
result = errors.mean()
print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, errors.mean()))
print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, result))
tf.summary.scalar("Validation_loss_label_%s" % loss_label, tf.constant(errors.mean()))
return result
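
Note: the validate() changes stop concatenating the three YNet heads into a single tensor. The MSE_PAIR error is computed per head (inverter, decoder_RE, decoder_IM) and the batch loop's output is averaged along axis 0 only, so the printed result is one reconstruction error per subnet instead of a single pooled scalar. Schematically, with illustrative numbers:

    # errors stacks one 3-vector per validated batch (values illustrative):
    import numpy as np
    errors = np.array([[0.12, 0.30, 0.28],    # inverter, decoder_RE, decoder_IM
                       [0.10, 0.29, 0.27]])
    errors.mean(0)   # -> array([0.11, 0.295, 0.275]), one error per head
    errors.mean()    # old behaviour: a single scalar pooled over all heads
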