Commit c8a6f3b3 authored by Laanait, Nouamane's avatar Laanait, Nouamane
Browse files

Mods to ynet model builder (upsampling, residual blocks,...)


Former-commit-id: 8fa3f4b6
parent 2af6d2b7
Loading
Loading
Loading
Loading
+6 −14
Original line number Diff line number Diff line
@@ -74,21 +74,11 @@ def calc_loss(n_net, scope, hyper_params, params, labels, images=None, summary=F
        _ = calculate_loss_regressor(n_net.model_output, labels, params, hyper_params)
    if hyper_params['network_type'] == 'YNet':
        weight=None
        probe_im = n_net.model_output['decoder']['IM']
        probe_re = n_net.model_output['decoder']['RE']
        probe_im = n_net.model_output['decoder_IM']
        probe_re = n_net.model_output['decoder_RE']
        pot = n_net.model_output['inverter']
        # probe_labels = tf.transpose(tf.reduce_mean(images, axis=1, keepdims=True), perm=[0,2,3,1])
        # probe_labels = tf.image.resize_bilinear(probe_labels, [128,128])
        # probe_labels = tf.transpose(probe_labels, perm=[0,3,1,2])
        # probe_arr = np.load('probe_amp.npy')
        # probe_arr = np.load('psi_k.npy')
        pot_labels, probe_labels_re, probe_labels_im = [tf.expand_dims(itm, axis=1) for itm in tf.unstack(labels, axis=1)]
        # probe_arr = np.expand_dims(np.expand_dims(probe_arr, axis=0), axis=0)
        # probe_arr = np.tile(probe_arr, [4,1,1,1])
        # probe_labels_re = tf.constant(np.abs(probe_arr), dtype=tf.float32)
        # probe_labels_im = tf.constant(np.angle(probe_arr), dtype=tf.float32)
        # pot_labels = labels
        # weight=10
        # weight=0.10
        inverter_loss = calculate_loss_regressor(pot, pot_labels, params, hyper_params, weight=weight)
        # weight=1
        decoder_loss_im = calculate_loss_regressor(probe_im, probe_labels_im, params, hyper_params, weight=weight)
@@ -111,7 +101,9 @@ def calc_loss(n_net, scope, hyper_params, params, labels, images=None, summary=F
    #Assemble all of the losses.
    losses = tf.get_collection(tf.GraphKeys.LOSSES)
    if hyper_params['network_type'] == 'YNet':
        losses = [inverter_loss , decoder_loss_im, decoder_loss_re ]
        # losses = [inverter_loss, decoder_loss_re]
        losses = [inverter_loss , decoder_loss_re, decoder_loss_im ]
        # losses = [inverter_loss , decoder_loss_im ]
    # losses = [inverter_loss]
    regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    # Calculate the total loss 
+183 −64
Original line number Diff line number Diff line
@@ -2131,7 +2131,7 @@ class YNet(FCDenseNet, FCNet):
    """
    def __init__(self, *args, **kwargs):
        super(YNet, self).__init__(*args, **kwargs)
        self.all_scopes = {"encoder": None, "decoder": None, "inverter": None}
        self.all_scopes = {"encoder": None, "decoder_RE": None, "decoder_IM": None, "inverter": None}
        self.all_ops = {"encoder": 0., "decoder": 0., "inverter": 0.}
        self.all_weights = {"encoder": 0., "decoder": 0., "inverter": 0.}
        self.all_mem = {"encoder": 0., "decoder": 0., "inverter": 0.}
@@ -2140,6 +2140,10 @@ class YNet(FCDenseNet, FCNet):
        self.model_output = {"encoder": None, "decoder": None, "inverter": None}
        self.network = dict([(key, itm) for key,itm in self.network.items()])

    def _batch_norm(self, input=None):
        out = tf.keras.layers.BatchNormalization(axis=1)(inputs=input, training= self.operation == 'train')
        return out

    def get_all_ops(self, subnet=None):
        if subnet is None:
            return 3 * np.sum([op for _, op in self.all_ops.items()])
@@ -2509,33 +2513,32 @@ class YNet(FCDenseNet, FCNet):
                    tens , _ = self._conv(input=tens, params=conv_1by1)
                    # tens = self._pool(input=tens, params=pool)
                    tens = self._activate(input=tens, params=conv_1by1)
            if tens.shape[-2:] != [32, 32]:
                tens = tf.transpose(tens, perm=[0, 2, 3, 1])
                tens = tf.image.resize(tens, [32, 32], method=tf.image.ResizeMethod.BILINEAR)
                if self.params['IMAGE_FP16']:
                    tens = tf.saturate_cast(tens, tf.float16)
                tens = tf.transpose(tens, perm=[0, 3, 1, 2])
            # if tens.shape[-2:] != [32, 32]:
            #     tens = tf.transpose(tens, perm=[0, 2, 3, 1])
            #     tens = tf.image.resize(tens, [32, 32], method=tf.image.ResizeMethod.BILINEAR)
            #     if self.params['IMAGE_FP16']:
            #         tens = tf.saturate_cast(tens, tf.float16)
            #     tens = tf.transpose(tens, perm=[0, 3, 1, 2])
            # self.print_rank('shape inside CVAE', tens.get_shape())
            for i in range(num_fc):
                with tf.variable_scope('CVAE_fc_%d' %i, reuse=self.reuse) as _ :
                    tens = self._linear(input=tens, params=fully_connected)
                    tens = self._activate(input=tens, params=fully_connected)
            # tens = tf.reshape(tens, [new_shape[0], -1])
            # for i in range(num_fc):
            #     with tf.variable_scope('CVAE_fc_%d' %i, reuse=self.reuse) as _ :
            #         tens = self._linear(input=tens, params=fully_connected)
            #         tens = self._activate(input=tens, params=fully_connected)
            # # tens = tf.reshape(tens, [new_shape[0], -1])
            return tens

        # post_ops = deepcopy(self.ops)
        # self.print_rank("post pre, cvae ops: ", pre_ops - post_ops)
        out = tf.map_fn(CVAE, tensor_slices, back_prop=True, swap_memory=True, parallel_iterations=256)
        # self.print_rank('output of CVAE', out.get_shape())
        out = tf.transpose(out, perm= [1, 2, 0])
        dim = int(math.sqrt(self.images.shape.as_list()[1]))
        out = tf.reshape(out, [self.params['batch_size'], -1, dim, dim])
        # out = tf.transpose(out, perm= [1, 2, 0])
        # dim = int(math.sqrt(self.images.shape.as_list()[1]))
        # out = tf.reshape(out, [self.params['batch_size'], -1, dim, dim])
        # out = tf.transpose(out, perm=[1,0,2,3])
        self.print_rank('output of Encoder', out.get_shape())
        self.model_output['encoder'] = out 
        self.update_all_attrs(subnet='encoder')



    def fully_connected_block(self, inputs, layer_params, scope_name):
        fc_params = OrderedDict({'type': 'fully_connected','weights': 1024,'bias': 1024, 'activation': layer_params['activation'],
                                   'regularize': True})
@@ -2568,10 +2571,11 @@ class YNet(FCDenseNet, FCNet):
                out = self._activate(input=out, params=fc_params) 
        return out

    def _build_branch(self, subnet='decoder', inputs=None):
    def _build_branch(self, subnet='decoder', scope= None, inputs=None):
        self.scopes = []
        self.print_rank('***** %s Branch ******' % subnet)
        network = self.network[subnet]
        subnet_scope = subnet if scope is None else scope
        # out = self.model_output['encoder']
        # if inputs is not None:
        #     out = inputs
@@ -2581,7 +2585,7 @@ class YNet(FCDenseNet, FCNet):
        out = inputs

        for layer_num, (layer_name, layer_params) in enumerate(list(network.items())):
            with tf.variable_scope(subnet+'_'+layer_name, reuse=self.reuse) as scope:
            with tf.variable_scope(subnet_scope+'_'+layer_name, reuse=self.reuse) as scope:
                in_shape = out.get_shape().as_list()
             
                if layer_params['type'] == 'conv_2D':
@@ -2607,6 +2611,15 @@ class YNet(FCDenseNet, FCNet):
                        self._activation_summary(out)
                        self._activation_image_summary(out)
                
                if layer_params['type'] == 'residual':
                    self.print_verbose(">>> Adding Residual Block: %s" % layer_name)
                    out, _ = self._residual_unit(inputs=out, params=layer_params)
                    out_shape = out.get_shape().as_list()
                    self.print_verbose('    output: %s' %format(out.get_shape().as_list()))
                    if self.summary: 
                        self._activation_summary(out)
                        self._activation_image_summary(out)

                if layer_params['type'] == 'depth_conv':
                    self.print_verbose(">>> Adding depthwise Conv Layer: %s" % layer_name)
                    self.print_verbose('    input: %s' %format(out.get_shape().as_list()))
@@ -2644,6 +2657,9 @@ class YNet(FCDenseNet, FCNet):
                if layer_params['type'] == 'deconv_2D':
                    self.print_verbose(">>> Adding de-Conv Layer: %s" % layer_name)
                    self.print_verbose('    input: %s' %format(out.get_shape().as_list()))
                    if subnet == 'inverter':
                        out = self._upscale(inputs=out, params=layer_params)
                    else:
                        out, _ = self._deconv(input=out, params=layer_params)
                    self.print_verbose('    output: %s' %format(out.get_shape().as_list()))
                    if self.summary: 
@@ -2692,21 +2708,50 @@ class YNet(FCDenseNet, FCNet):
                                                                                    self.num_weights,
                                                                                    format(self.mem / 1024),
                                                                                    self.get_ops()))
        if subnet == 'inverter':
            out = tf.reduce_mean(out, axis=1, keepdims=True)
        self.model_output[subnet] = out
        self.update_all_attrs(subnet=subnet)
    
    def build_decoder(self):
    def _upscale(self, inputs=None, params=None, scale=2):
        conv_params = OrderedDict({'type': 'conv_2D', 'stride': [1, 1], 'kernel': [1, 1], 
                                'features': inputs.shape.as_list()[1]//2,
                                'activation': 'relu', 
                                'padding': 'VALID', 
                                'batch_norm': False, 'dropout':0.0})
        with tf.variable_scope('upscale', reuse=self.reuse) as _:
            shape = inputs.shape
            out = tf.reshape(inputs, [-1, shape[1], shape[2], 1, shape[3], 1])
            out = tf.tile(out, [1, 1, 1, scale, 1, scale])
            out = tf.reshape(out, [-1, shape[1], shape[2] * scale, shape[3] * scale])
            out, _ = self._conv(input=out, params=conv_params)
            return out

    def _residual_unit(self, inputs=None, params=None):
        conv_params = OrderedDict({'type': 'conv_2D', 'stride': [1, 1], 'kernel': [1, 1], 
                                'features': inputs.shape.as_list()[1],
                                'activation': 'relu', 
                                'padding': 'VALID', 
                                'batch_norm': False, 'dropout':0.0})
        out = self._batch_norm(input=inputs)
        out = self._activate(input=out, params=conv_params)
        with tf.variable_scope('residual_conv_1', reuse=self.reuse) as scope:
            out, _ = self._conv(input=out, params=conv_params)
        out = self._batch_norm(input=out)
        out = self._activate(input=out, params=conv_params)
        with tf.variable_scope('residual_conv_2', reuse=self.reuse) as scope:
            out, _ = self._conv(input=out, params=conv_params)
        out = tf.add(inputs, out)
        return out, None

    def build_decoder_RE(self):
        out = self.model_output['encoder']
        # out = self.fully_connected_block(out, self.network['encoder']['fully_connected_block'], 'decoder')
        # out = tf.transpose(out, perm=[0, 2, 1])
        # dim = int(math.sqrt(self.images.shape.as_list()[1]))
        # out = tf.reshape(out, [self.params['batch_size'], -1, dim, dim])
        conv_1by1 = OrderedDict({'type': 'conv_2D', 'stride': [1, 1], 'kernel': [1, 1], 
                                'features': 1024,
                                'activation': 'relu', 
                                'padding': 'VALID', 
                                'batch_norm': False, 'dropout':0.0})
        with tf.variable_scope('decoder_conv_1by1', reuse=self.reuse) as _:
        with tf.variable_scope('decoder_RE_conv_1by1', reuse=self.reuse) as _:
            out, _ = self._conv(input=out, params=conv_1by1) 
            do_bn = conv_1by1.get('batch_norm', False)
            if do_bn:
@@ -2724,67 +2769,141 @@ class YNet(FCDenseNet, FCNet):
        #         out = self._activate(input=out, params=fc_params)
        # dim = int(math.sqrt(fc_params['weights']))
        # out = tf.reshape(out, [self.params['batch_size'], 1, dim, dim])
        self._build_branch(subnet='decoder', inputs=out)
        self._build_branch(subnet='decoder_RE', inputs=out)

        conv_1by1 = OrderedDict({'type': 'conv_2D', 'stride': [1, 1], 'kernel': [1, 1], 'features': 1,
                            'activation': 'relu', 'padding': 'SAME', 'batch_norm': False})
        with tf.variable_scope('decoder_CONV_FIN_RE', reuse=self.reuse) as _:
            out, _ = self._conv(input=self.model_output['decoder'], params=conv_1by1)
        with tf.variable_scope('decoder_RE_CONV_FIN', reuse=self.reuse) as _:
            out, _ = self._conv(input=self.model_output['decoder_RE'], params=conv_1by1)
            do_bn = conv_1by1.get('batch_norm', False)
            if do_bn:
                out = self._batch_norm(input=out)
            else:
                out = self._add_bias(input=out, params=conv_1by1)
            out = self._activate(input=out, params=conv_1by1)
        outputs = {'RE': out, 'IM': None}
        with tf.variable_scope('decoder_CONV_FIN_IM', reuse=self.reuse) as _:
            out, _ = self._conv(input=self.model_output['decoder'], params=conv_1by1)
            do_bn = conv_1by1.get('batch_norm', False)
        self.model_output['decoder_RE'] = out

    def build_decoder(self, subnet='decoder_IM'):
        scopes_list = []
        out = self.model_output['encoder']
        out_shape = out.shape.as_list()
        params = self.network['encoder']['freq2space']['cvae_params']
        fully_connected = params['fc_params']
        num_fc = params['n_fc_layers']
        conv_1by1_1 = OrderedDict({'type': 'conv_2D', 'stride': [1, 1], 'kernel': [1, 1], 
                                'features': 1,
                                'activation': 'relu', 
                                'padding': 'VALID', 
                                'batch_norm': False, 'dropout':0.0})
        conv_1by1_1024 = OrderedDict({'type': 'conv_2D', 'stride': [1, 1], 'kernel': [1, 1], 
        'features': 1024,
        'activation': 'relu', 
        'padding': 'VALID', 
        'batch_norm': False, 'dropout':0.0})
        if False:
            def fc_map(tens):
                for i in range(num_fc):
                    with tf.variable_scope('%s_fc_%d' %(subnet, i), reuse=self.reuse) as scope :
                        tens = self._linear(input=tens, params=fully_connected)
                        tens = self._activate(input=tens, params=fully_connected)
                        # scopes_list.append(scope)
                return tens
            out = tf.map_fn(fc_map, out, back_prop=True, swap_memory=True, parallel_iterations=256)
            out = tf.transpose(out, perm= [1, 2, 0])
            dim = int(math.sqrt(self.images.shape.as_list()[1]))
            out = tf.reshape(out, [self.params['batch_size'], -1, dim, dim])
        else:
            out = tf.reshape(out, [out_shape[0]*out_shape[1], out_shape[2], out_shape[3], out_shape[4]])
            with tf.variable_scope('%s_conv_1by1_1' % subnet, reuse=self.reuse) as scope:
                out, _ = self._conv(input=out, params=conv_1by1_1)
                out = tf.reshape(out, [out_shape[0], out_shape[1], out_shape[3], out_shape[4]])
                out = tf.transpose(out, perm=[1,0,2,3])
                # scopes_list.append(scope)
                print('conv1by1_decoder shape', out.shape.as_list())
        with tf.variable_scope('%s_conv_1by1_1024' % subnet, reuse=self.reuse) as scope:
            out, _ = self._conv(input=out, params=conv_1by1_1024) 
            do_bn = conv_1by1_1024.get('batch_norm', False)
            if do_bn:
                out = self._batch_norm(input=out)
            else:
                out = self._add_bias(input=out, params=conv_1by1)
            out = self._activate(input=out, params=conv_1by1) 
        outputs['IM'] = out
        self.model_output['decoder'] = outputs
        # Split last conv layer to predict amplitude and phase

                out = self._add_bias(input=out, params=conv_1by1_1024)
            out = self._activate(input=out, params=conv_1by1_1024)
            # scopes_list.append(scope)

        self._build_branch(subnet=subnet, inputs=out)

        # conv_1by1 = OrderedDict({'type': 'conv_2D', 'stride': [1, 1], 'kernel': [1, 1], 'features': 1,
        #                     'activation': 'relu', 'padding': 'SAME', 'batch_norm': False})
        # with tf.variable_scope('%s_CONV_FIN' % subnet, reuse=self.reuse) as scope:
        #     out, _ = self._conv(input=self.model_output[subnet], params=conv_1by1)
        #     do_bn = conv_1by1.get('batch_norm', False)
        #     if do_bn:
        #         out = self._batch_norm(input=out)
        #     else:
        #         out = self._add_bias(input=out, params=conv_1by1)
        #     out = self._activate(input=out, params=conv_1by1)
        #     scopes_list.append(scope)
        # self.model_output[subnet] = out
        # self.all_scopes[subnet] += scopes_list

    def build_inverter(self):
        out = self.model_output['encoder']
        # out = self.fully_connected_block(out, self.network['encoder']['fully_connected_block'], 'inverter')
        # out = tf.transpose(out, perm=[0, 2, 1])
        # dim = int(math.sqrt(self.images.shape.as_list()[1]))
        # out = tf.reshape(out, [self.params['batch_size'], -1, dim, dim])
        conv_1by1 = OrderedDict({'type': 'conv_2D', 'stride': [1, 1], 'kernel': [1, 1], 
        params = self.network['encoder']['freq2space']['cvae_params']
        fully_connected = params['fc_params']
        num_fc = params['n_fc_layers']
        scopes_list = []
        if True:
            def fc_map(tens):
                for i in range(num_fc):
                    with tf.variable_scope('Inverter_fc_%d' %i, reuse=self.reuse) as scope :
                        tens = self._linear(input=tens, params=fully_connected)
                        tens = self._activate(input=tens, params=fully_connected)
                        # scopes_list.append(scope)
                return tens
            out = tf.map_fn(fc_map, out, back_prop=True, swap_memory=True, parallel_iterations=256)
            out = tf.transpose(out, perm= [1, 2, 0])
            dim = int(math.sqrt(self.images.shape.as_list()[1]))
            out = tf.reshape(out, [self.params['batch_size'], -1, dim, dim])
        else:
            conv_1by1_1 = OrderedDict({'type': 'conv_2D', 'stride': [1, 1], 'kernel': [1, 1], 
                                'features': 1,
                                'activation': 'relu', 
                                'padding': 'VALID', 
                                'batch_norm': False, 'dropout':0.0})
            out_shape = out.shape.as_list()
            out = tf.reshape(out, [out_shape[0]*out_shape[1], out_shape[2], out_shape[3], out_shape[4]])
            with tf.variable_scope('%s_conv_1by1_1' % 'inverter', reuse=self.reuse) as scope:
                out, _ = self._conv(input=out, params=conv_1by1_1)
                out = tf.reshape(out, [out_shape[0], out_shape[1], out_shape[3], out_shape[4]])
                out = tf.transpose(out, perm=[1,0,2,3])
                scopes_list.append(scope)
                print('conv1by1_inverter shape', out.shape.as_list())
        conv_1by1_1024 = OrderedDict({'type': 'conv_2D', 'stride': [1, 1], 'kernel': [1, 1], 
            'features': 1024,
            'activation': 'relu', 
            'padding': 'VALID', 
            'batch_norm': False, 'dropout':0.0})

        with tf.variable_scope('inverter_conv_1by1', reuse=self.reuse) as _:
            out, _ = self._conv(input=out, params=conv_1by1) 
            do_bn = conv_1by1.get('batch_norm', False)
        with tf.variable_scope('inverter_conv_1by1_1024', reuse=self.reuse) as scope:
            out, _ = self._conv(input=out, params=conv_1by1_1024) 
            do_bn = conv_1by1_1024.get('batch_norm', False)
            if do_bn:
                out = self._batch_norm(input=out)
            else:
                out = self._add_bias(input=out, params=conv_1by1)
            out = self._activate(input=out, params=conv_1by1)
        # fc_params = OrderedDict({'type': 'fully_connected','weights': 256,'bias': 256, 'activation': 'relu',
        #                            'regularize': True})
        # num_fc = self.network['inverter']['freq2space']['n_fc_layers']
        # num_fc =1
        # for i in range(num_fc):
        #     with tf.variable_scope('inverter_freq2space_fc_%d' %i, reuse=self.reuse) as _ :
        #         out = self._linear(input=out, params=fc_params)
        #         out = self._activate(input=out, params=fc_params)
        # dim = int(math.sqrt(fc_params['weights']))
        # out = tf.reshape(out, [self.params['batch_size'], 1, dim, dim])
                out = self._add_bias(input=out, params=conv_1by1_1024)
            out = self._activate(input=out, params=conv_1by1_1024)
            # scopes_list.append(scope)
        self._build_branch(subnet='inverter', inputs=out)
        self.all_scopes['inverter'] += scopes_list

    def build_model(self):
        self.build_encoder()
        self.build_decoder()
        self.build_decoder(subnet='decoder_IM')
        self.build_decoder(subnet='decoder_RE')
        self.build_inverter()
        self.scopes = list(chain.from_iterable([scope for _, scope in self.all_scopes.items()]))
        # self.scopes = list(chain.from_iterable([scope for _, scope in self.all_scopes.items()]))
        # self.scopes = None
        for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            self.print_rank("var:%s , dtype:%s" % (var.name, var.dtype))
        ##TODO: check why one of the GRADIENTS COMES OUT AS INT32 WHEN DOING BATCHNORM

    
+40 −82

File changed.

Preview size limit exceeded, changes collapsed.

+2 −2
Original line number Diff line number Diff line
@@ -180,7 +180,7 @@ def reduce_gradients(grads_and_vars, on_horovod, model=None, run_params=None):
      with tf.name_scope("all_reduce"):
        for idx, (grad, var) in enumerate(grads_and_vars):
          if grad is not None:
            print("rank: %d, grad: %s, var:%s" %(rank(), grad.name, var.name))
            # print("rank: %d, grad: %s, var:%s" %(rank(), grad.name, var.name))
            avg_grad = allreduce(grad)
            averaged_grads_and_vars.append((avg_grad, var))
          else:
@@ -454,7 +454,7 @@ def post_process_gradients(grads_and_vars, summaries, lr,
      if len(ind_list) >= 1:
          layer_grads = [grads_and_vars[ind][0] for ind in ind_list]
          layer_vars = [grads_and_vars[ind][1] for ind in ind_list]
          grad_vec = tf.concat([tf.expand_dims(tf.reshape(grad, [-1]), 0) for grad in layer_grads], 1)
          grad_vec = tf.concat([tf.expand_dims(tf.reshape(tf.cast(grad, tf.float32), [-1]), 0) for grad in layer_grads], 1)
          var_vec = tf.concat([tf.expand_dims(tf.reshape(var, [-1]), 0) for var in layer_vars], 1)
          var_dtype = layer_vars[0].dtype
          var_nom = tf.norm(tensor=tf.cast(var_vec, tf.float32))
+5 −5
Original line number Diff line number Diff line
@@ -309,11 +309,11 @@ def train(network_config, hyper_params, params):
            tf.summary.image("inputs", tf.transpose(tf.reduce_mean(images, axis=1, keepdims=True), perm=[0,2,3,1]), max_outputs=4)
        elif hyper_params['network_type'] == 'YNet': 
            predic_inverter = tf.transpose(n_net.model_output['inverter'], perm=[0,2,3,1])
            tf.summary.image("output_inverter", predic_inverter, max_outputs=4) 
            predic_decoder_RE = tf.transpose(n_net.model_output['decoder']['RE'], perm=[0,2,3,1])
            predic_decoder_IM = tf.transpose(n_net.model_output['decoder']['IM'], perm=[0,2,3,1])
            tf.summary.image("output_decoder_RE", predic_decoder_RE, max_outputs=4)
            tf.summary.image("output_decoder_IM", predic_decoder_IM, max_outputs=4)
            tf.summary.image("output_inverter", predic_inverter, max_outputs=2) 
            predic_decoder_RE = tf.transpose(n_net.model_output['decoder_RE'], perm=[0,2,3,1])
            predic_decoder_IM = tf.transpose(n_net.model_output['decoder_IM'], perm=[0,2,3,1])
            tf.summary.image("output_decoder_RE", predic_decoder_RE, max_outputs=2)
            tf.summary.image("output_decoder_IM", predic_decoder_IM, max_outputs=2)
            new_labels = tf.unstack(labels, axis=1)
            for label, tag in zip(new_labels, ['potential', 'probe_RE', 'probe_IM']):
                label = tf.expand_dims(label, axis=-1)