Commit a951a885 authored by Laanait, Nouamane's avatar Laanait, Nouamane
Browse files

filtering fft before use in ynet loss functions

parent 6f8727e1
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -258,7 +258,8 @@ def fftshift(tensor, tens_format='NCHW'):

def thin_object(psi_k_re, psi_k_im, potential):
    mask = np.zeros(psi_k_re.shape.as_list(), dtype=np.float32)
    center = slice(mask.shape[-1]//5, 4 * mask.shape[-1]//5) 
    ratio = 0.33
    center = slice(int(ratio * mask.shape[-1]), int((1-ratio)* mask.shape[-1]))
    mask[:,:,center,center] = 1.
    mask = tf.constant(mask, dtype=tf.complex64)
    psi_x = fftshift(tf.ifft2d(mask * tf.cast(psi_k_re, tf.complex64) * tf.exp( 1.j * tf.cast(psi_k_im, tf.complex64))))