Commit 6d8ddd4d authored by Laanait, Nouamane's avatar Laanait, Nouamane

new training function and class for YNet

parent 10d37686
Pipeline #80781 failed with stage
in 2 minutes and 34 seconds
......@@ -118,6 +118,8 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
total_loss = tf.cast(total_loss, tf.float32)
# Generate summaries for the losses and get corresponding op
loss_averages_op = _add_loss_summaries(total_loss, losses, summaries=summary)
if hyper_params['network_type'] == 'YNet':
return total_loss, loss_averages_op, losses
return total_loss, loss_averages_op
def fully_connected(n_net, layer_params, batch_size, wd=0, name=None, reuse=None):
This diff is collapsed.
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment