Commit 2b787e5e authored by Laanait, Nouamane's avatar Laanait, Nouamane

adding YNet loss hyper-params

parent 6d8ddd4d
Pipeline #80854 failed with stage
in 7 minutes and 50 seconds
......@@ -82,12 +82,15 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
pot = n_net.model_output['inverter']
pot_labels, probe_labels_re, probe_labels_im = [tf.expand_dims(itm, axis=1) for itm in tf.unstack(labels, axis=1)]
#weight= np.prod(pot_labels.shape.as_list()[-2:])
weight=None
inverter_loss = calculate_loss_regressor(pot, pot_labels, params, hyper_params, weight=weight)
decoder_loss_im = calculate_loss_regressor(probe_im, probe_labels_im, params, hyper_params, weight=weight)
decoder_loss_re = calculate_loss_regressor(probe_re, probe_labels_re, params, hyper_params, weight=weight)
weight=None
inv_str = hyper_params.get('inv_strength', 0.1)
reg_str = hyper_params.get('reg_strength', 0.1)
dec_str = hyper_params.get('dec_strength', 1)
inverter_loss = inv_str * calculate_loss_regressor(pot, pot_labels, params, hyper_params, weight=weight)
decoder_loss_im = dec_str * calculate_loss_regressor(probe_im, probe_labels_im, params, hyper_params, weight=weight)
decoder_loss_re = dec_str * calculate_loss_regressor(probe_re, probe_labels_re, params, hyper_params, weight=weight)
psi_out_mod = thin_object(probe_re, probe_im, pot)
reg_loss = calculate_loss_regressor(psi_out_mod, tf.reduce_mean(images, axis=[1], keepdims=True),
reg_loss = reg_str * calculate_loss_regressor(psi_out_mod, tf.reduce_mean(images, axis=[1], keepdims=True),
params, hyper_params, weight=weight)
tf.summary.scalar('reg_loss ', reg_loss)
tf.summary.scalar('Inverter loss ', inverter_loss)
......@@ -109,7 +112,7 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
losses = tf.get_collection(tf.GraphKeys.LOSSES)
if hyper_params['network_type'] == 'YNet':
reg_str = hyper_params.get('reg_strength', 0.1)
losses = [inverter_loss , decoder_loss_re, decoder_loss_im, reg_str * reg_loss]
losses = [inverter_loss , decoder_loss_re, decoder_loss_im, reg_loss]
# losses, prefac = ynet_adjusted_losses(losses, step)
regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
# Calculate the total loss
......@@ -249,22 +252,22 @@ def fftshift(tensor, tens_format='NCHW'):
return shift_tensor
def thin_object(psi_k_re, psi_k_im, potential):
mask = np.zeros(psi_k_re.shape.as_list(), dtype=np.float32)
ratio = 0
if ratio == 0:
center = slice(None, None)
else:
center = slice(int(ratio * mask.shape[-1]), int((1-ratio)* mask.shape[-1]))
mask[:,:,center,center] = 1.
mask = tf.constant(mask, dtype=tf.complex64)
psi_x = fftshift(tf.ifft2d(mask * tf.cast(psi_k_re, tf.complex64) * tf.exp( 1.j * tf.cast(psi_k_im, tf.complex64))))
# mask = np.zeros(psi_k_re.shape.as_list(), dtype=np.float32)
# ratio = 0
# if ratio == 0:
# center = slice(None, None)
# else:
# center = slice(int(ratio * mask.shape[-1]), int((1-ratio)* mask.shape[-1]))
# mask[:,:,center,center] = 1.
# mask = tf.constant(mask, dtype=tf.complex64)
psi_x = fftshift(tf.ifft2d(tf.cast(psi_k_re, tf.complex64) * tf.exp( 1.j * tf.cast(psi_k_im, tf.complex64))))
scan_range = psi_x.shape.as_list()[-1]//2
vx, vy = np.linspace(-scan_range, scan_range, num=4), np.linspace(-scan_range, scan_range, num=4)
X, Y = np.meshgrid(vx.astype(np.int), vy.astype(np.int))
psi_x_stack = [tf.roll(psi_x, shift=[x,y], axis=[1,2]) for (x,y) in zip(X.flatten(), Y.flatten())]
psi_x_stack = tf.concat(psi_x_stack, axis=1)
pot_frac = tf.exp(1.j * tf.cast(potential, tf.complex64))
psi_out = tf.fft2d(mask * psi_x_stack * pot_frac / np.prod(psi_x.shape.as_list()))
psi_out = tf.fft2d(psi_x_stack * pot_frac / np.prod(psi_x.shape.as_list()))
psi_out_mod = tf.cast(tf.abs(psi_out), tf.float32) ** 2
psi_out_mod = tf.reduce_mean(psi_out_mod, axis=1, keep_dims=True)
tf.summary.image('Psi_k_out', tf.transpose(tf.abs(psi_out_mod)**0.25, perm=[0,2,3,1]), max_outputs=1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment