Commit 2b787e5e authored by Laanait, Nouamane's avatar Laanait, Nouamane
Browse files

adding YNet loss hyper-params

parent 6d8ddd4d
Loading
Loading
Loading
Loading
Loading
+19 −16
Original line number Diff line number Diff line
@@ -83,11 +83,14 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
        pot_labels, probe_labels_re, probe_labels_im = [tf.expand_dims(itm, axis=1) for itm in tf.unstack(labels, axis=1)]
        #weight= np.prod(pot_labels.shape.as_list()[-2:])
        weight=None
        inverter_loss = calculate_loss_regressor(pot, pot_labels, params, hyper_params, weight=weight)
        decoder_loss_im = calculate_loss_regressor(probe_im, probe_labels_im, params, hyper_params, weight=weight)
        decoder_loss_re = calculate_loss_regressor(probe_re, probe_labels_re, params, hyper_params, weight=weight)
        inv_str = hyper_params.get('inv_strength', 0.1)
        reg_str = hyper_params.get('reg_strength', 0.1)
        dec_str = hyper_params.get('dec_strength', 1) 
        inverter_loss = inv_str * calculate_loss_regressor(pot, pot_labels, params, hyper_params, weight=weight)
        decoder_loss_im = dec_str * calculate_loss_regressor(probe_im, probe_labels_im, params, hyper_params, weight=weight)
        decoder_loss_re = dec_str * calculate_loss_regressor(probe_re, probe_labels_re, params, hyper_params, weight=weight)
        psi_out_mod = thin_object(probe_re, probe_im, pot)
        reg_loss = calculate_loss_regressor(psi_out_mod, tf.reduce_mean(images, axis=[1], keepdims=True), 
        reg_loss = reg_str * calculate_loss_regressor(psi_out_mod, tf.reduce_mean(images, axis=[1], keepdims=True), 
                    params, hyper_params, weight=weight)
        tf.summary.scalar('reg_loss ', reg_loss)
        tf.summary.scalar('Inverter loss ', inverter_loss)
@@ -109,7 +112,7 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
    losses = tf.get_collection(tf.GraphKeys.LOSSES)
    if hyper_params['network_type'] == 'YNet':
        reg_str = hyper_params.get('reg_strength', 0.1)
        losses = [inverter_loss , decoder_loss_re, decoder_loss_im, reg_str * reg_loss]
        losses = [inverter_loss , decoder_loss_re, decoder_loss_im, reg_loss]
        # losses, prefac = ynet_adjusted_losses(losses, step)
    regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    # Calculate the total loss 
@@ -249,22 +252,22 @@ def fftshift(tensor, tens_format='NCHW'):
    return shift_tensor

def thin_object(psi_k_re, psi_k_im, potential):
    mask = np.zeros(psi_k_re.shape.as_list(), dtype=np.float32)
    ratio = 0
    if ratio == 0:
        center = slice(None, None) 
    else:
        center = slice(int(ratio * mask.shape[-1]), int((1-ratio)* mask.shape[-1]))
    mask[:,:,center,center] = 1.
    mask = tf.constant(mask, dtype=tf.complex64)
    psi_x = fftshift(tf.ifft2d(mask * tf.cast(psi_k_re, tf.complex64) * tf.exp( 1.j * tf.cast(psi_k_im, tf.complex64))))
    # mask = np.zeros(psi_k_re.shape.as_list(), dtype=np.float32)
    # ratio = 0
    # if ratio == 0:
    #     center = slice(None, None) 
    # else:
    #     center = slice(int(ratio * mask.shape[-1]), int((1-ratio)* mask.shape[-1]))
    # mask[:,:,center,center] = 1.
    # mask = tf.constant(mask, dtype=tf.complex64)
    psi_x = fftshift(tf.ifft2d(tf.cast(psi_k_re, tf.complex64) * tf.exp( 1.j * tf.cast(psi_k_im, tf.complex64))))
    scan_range = psi_x.shape.as_list()[-1]//2
    vx, vy = np.linspace(-scan_range, scan_range, num=4), np.linspace(-scan_range, scan_range, num=4)
    X, Y = np.meshgrid(vx.astype(np.int), vy.astype(np.int))
    psi_x_stack = [tf.roll(psi_x, shift=[x,y], axis=[1,2]) for (x,y) in zip(X.flatten(), Y.flatten())]
    psi_x_stack = tf.concat(psi_x_stack, axis=1)
    pot_frac = tf.exp(1.j * tf.cast(potential, tf.complex64))
    psi_out = tf.fft2d(mask * psi_x_stack * pot_frac / np.prod(psi_x.shape.as_list()))
    psi_out = tf.fft2d(psi_x_stack * pot_frac / np.prod(psi_x.shape.as_list()))
    psi_out_mod = tf.cast(tf.abs(psi_out), tf.float32) ** 2
    psi_out_mod = tf.reduce_mean(psi_out_mod, axis=1, keep_dims=True)
    tf.summary.image('Psi_k_out', tf.transpose(tf.abs(psi_out_mod)**0.25, perm=[0,2,3,1]), max_outputs=1)