Commit 691cc156 authored by Nouamane Laanait
Browse files

adding masking in loss calculation of ynet


Former-commit-id: 6c3feb3a
parent 9f6afb1c
Loading
Loading
Loading
Loading
+20 −13
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ def _add_loss_summaries(total_loss, losses, summaries=False):
    return loss_averages_op


def calc_loss(n_net, scope, hyper_params, params, labels, images=None, summary=False):
def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None, summary=False):
    labels_shape = labels.get_shape().as_list()
    layer_params={'bias':labels_shape[-1], 'weights':labels_shape[-1],'regularize':True}
    if hyper_params['network_type'] == 'hybrid':
@@ -85,6 +85,12 @@ def calc_loss(n_net, scope, hyper_params, params, labels, images=None, summary=F
        # weight=0.10
        inverter_loss = calculate_loss_regressor(pot, pot_labels, params, hyper_params, weight=weight)
        # weight=1
        # probe_shape = probe_labels_re.shape.as_list()
        # mask = np.ones(probe_shape, dtype=np.float32)
        # snapshot = slice(probe_shape[-1]// 4, 3 * probe_shape[-1]//4)
        # mask[:,:, snapshot, snapshot] = 100.0
        # #mask = np.expand_dims(np.expand_dims(mask, axis=0), axis=0)
        # weight = tf.constant(mask)1
        decoder_loss_im = calculate_loss_regressor(probe_im, probe_labels_im, params, hyper_params, weight=weight)
        decoder_loss_re = calculate_loss_regressor(probe_re, probe_labels_re, params, hyper_params, weight=weight)
        tf.summary.scalar('Inverter loss (raw)', inverter_loss)
@@ -106,7 +112,8 @@ def calc_loss(n_net, scope, hyper_params, params, labels, images=None, summary=F
    losses = tf.get_collection(tf.GraphKeys.LOSSES)
    if hyper_params['network_type'] == 'YNet':
        losses = [inverter_loss , decoder_loss_re, decoder_loss_im ]
        losses = ynet_adjusted_losses(losses, tf.train.get_or_create_global_step())
        # losses, prefac = ynet_adjusted_losses(losses, step)
        # tf.summary.scalar("prefac_inverter", prefac)
    regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    # Calculate the total loss 
    total_loss = tf.add_n(losses + regularization, name='total_loss')
def ynet_adjusted_losses(losses, global_step):
    '''
    Schedule the different YNet loss components based on the global training step.

    An exponential moving average op is applied over the individual losses, and
    the inverter loss is re-weighted by a step-dependent prefactor: while the
    inverter loss is still above `threshold` the prefactor ramps up (capped at
    `max_prefac`); once it falls below the threshold the prefactor is decayed
    on a staircase schedule.

    Args:
        losses: list of three scalar loss tensors
            [inverter_loss, decoder_loss_re, decoder_loss_im].
        global_step: scalar integer tensor holding the current training step.

    Returns:
        (losses, prefac): the re-weighted loss list (inverter term replaced by
        prefac * (inverter_loss - threshold)) and the scalar prefactor, exposed
        for summary/monitoring. NOTE(review): callers must unpack this 2-tuple;
        verify every call site does so.
    '''
    threshold = tf.constant(0.25)
    max_prefac = 20
    ema = tf.train.ExponentialMovingAverage(0.9)
    loss_averages_op = ema.apply(losses)
    prefac_initial = 2.0

    with tf.control_dependencies([loss_averages_op]):
        def ramp():
            # Inverted exponential decay: starts at prefac_initial and grows
            # as global_step increases (decay rate 0.5 per 10000 steps,
            # inverted by tf.pow with exponent -1), clipped at max_prefac.
            prefac = tf.cast(tf.train.exponential_decay(tf.constant(prefac_initial), global_step, 10000,
                            tf.constant(0.5, dtype=tf.float32), staircase=False), tf.float32)
            prefac = tf.constant(prefac_initial) ** 2 * tf.pow(prefac, tf.constant(-1., dtype=tf.float32))
            prefac = tf.minimum(prefac, tf.cast(max_prefac, tf.float32))
            return prefac

        def decay(prefac_current):
            # Halve prefac_current every 1000 steps (staircase schedule).
            prefac = tf.train.exponential_decay(prefac_current, global_step, 1000, tf.constant(0.5, dtype=tf.float32),
                                            staircase=True)
            return prefac

        # NOTE(review): the raw loss tensors are used here, not their EMA
        # values (ema.average lookup is commented out upstream) — confirm
        # that comparing the un-smoothed inverter loss against threshold
        # is intended.
        inv_loss, dec_re_loss, dec_im_loss = losses

        # Ramp the inverter weight while its loss is above threshold;
        # otherwise decay the ramped value.
        prefac = tf.cond(inv_loss > threshold, true_fn=ramp, false_fn=lambda: decay(ramp()))
        tf.summary.scalar("prefac_inverter", prefac)
        losses = [prefac * (inv_loss - threshold), dec_re_loss, dec_im_loss]
        return losses, prefac
+3 −9
Original line number Diff line number Diff line
@@ -255,7 +255,6 @@ class ConvNet:
            'regularize':True}
        with tf.variable_scope('linear_output', reuse=self.reuse) as scope:
            output = self._linear(input=self.model_output, name=scope.name, params=layer_params)
            print(output.name)
        if self.params['IMAGE_FP16']:
            output = tf.cast(output, tf.float32)
            return output
@@ -288,7 +287,6 @@ class ConvNet:
                'regularize':True}
            with tf.variable_scope(layer_name, reuse=self.reuse) as scope:
                out = tf.cast(self._linear(input=self.model_output, name=scope.name, params=layer_params), tf.float32)
                print(out.name)
            self.print_rank('Output Layer : %s' %format(out.get_shape().as_list()))
            outputs.append(out)
        mixing = self.hyper_params['mixing']
@@ -664,8 +662,6 @@ class ConvNet:
                lin_initializer.factor = 1.0
        elif isinstance(lin_initializer, tf.random_normal_initializer):
            init_val = max(np.sqrt(2.0 / params['weights']), 0.01)
            if verbose:
                print('stddev: %s' % format(init_val))
            lin_initializer.mean = 0.0
            lin_initializer.stddev = init_val

@@ -2646,7 +2642,7 @@ class YNet(FCDenseNet, FCNet):

        # post_ops = deepcopy(self.ops)
        # self.print_rank("post pre, cvae ops: ", pre_ops - post_ops)
        out = tf.map_fn(CVAE, tensor_slices, back_prop=True, swap_memory=True, parallel_iterations=256)
        out = tf.map_fn(CVAE, tensor_slices, back_prop=True)
        # self.print_rank('output of CVAE', out.get_shape())
        # out = tf.transpose(out, perm= [1, 2, 0])
        # dim = int(math.sqrt(self.images.shape.as_list()[1]))
@@ -2873,7 +2869,7 @@ class YNet(FCDenseNet, FCNet):
        'activation': 'relu', 
        'padding': 'VALID', 
        'batch_norm': True, 'dropout':0.0})
        if False:
        if True:
            def fc_map(tens):
                for i in range(num_fc):
                    with tf.variable_scope('%s_fc_%d' %(subnet, i), reuse=self.reuse) as scope :
@@ -2881,7 +2877,7 @@ class YNet(FCDenseNet, FCNet):
                        tens = self._activate(input=tens, params=fully_connected)
                        # scopes_list.append(scope)
                return tens
            out = tf.map_fn(fc_map, out, back_prop=True, swap_memory=True, parallel_iterations=256)
            out = tf.map_fn(fc_map, out, back_prop=True)
            out = tf.transpose(out, perm= [1, 2, 0])
            dim = int(math.sqrt(self.images.shape.as_list()[1]))
            out = tf.reshape(out, [self.params['batch_size'], -1, dim, dim])
@@ -2893,7 +2889,6 @@ class YNet(FCDenseNet, FCNet):
                out = tf.transpose(out, perm=[1,0,2,3])

                # scopes_list.append(scope)
                print('conv1by1_decoder shape', out.shape.as_list())
        with tf.variable_scope('%s_conv_1by1_1024' % subnet, reuse=self.reuse) as scope:
            out, _ = self._conv(input=out, params=conv_1by1_1024) 
            out = self._activate(input=out, params=conv_1by1_1024)
@@ -2951,7 +2946,6 @@ class YNet(FCDenseNet, FCNet):
                out = tf.reshape(out, [out_shape[0], out_shape[1], out_shape[3], out_shape[4]])
                out = tf.transpose(out, perm=[1,0,2,3])
                scopes_list.append(scope)
                print('conv1by1_inverter shape', out.shape.as_list())
        conv_1by1_1024 = OrderedDict({'type': 'conv_2D', 'stride': [1, 1], 'kernel': [1, 1], 
            'features': 1024,
            'activation': 'relu', 
+1 −1
Original line number Diff line number Diff line
@@ -250,7 +250,7 @@ def train(network_config, hyper_params, params):
            n_net.build_model()

            # calculate the total loss
            total_loss, loss_averages_op = losses.calc_loss(n_net, scope, hyper_params, params, labels, images=images, summary=summary)
            total_loss, loss_averages_op = losses.calc_loss(n_net, scope, hyper_params, params, labels, step=global_step, images=images, summary=summary)

            #get summaries, except for the one produced by string_input_producer
            if summary: summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)