Commit 81ddd819 authored by Laanait, Nouamane

minor changes to validate function for ynet

parent 26de6b23
+2 −4
@@ -108,11 +108,9 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
     #Assemble all of the losses.
     losses = tf.get_collection(tf.GraphKeys.LOSSES)
     if hyper_params['network_type'] == 'YNet':
-        losses = [inverter_loss , decoder_loss_re, decoder_loss_im, reg_loss]
-        # losses = [inverter_loss , decoder_loss_re, decoder_loss_im]
+        reg_str = hyper_params.get('reg_strength', 0.1)
+        losses = [inverter_loss , decoder_loss_re, decoder_loss_im, reg_str * reg_loss]
-        # losses, prefac = ynet_adjusted_losses(losses, step)
-        # tf.summary.scalar("prefac_inverter", prefac)
         # losses = [inverter_loss]
     regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
     # Calculate the total loss
     total_loss = tf.add_n(losses + regularization, name='total_loss')
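
The change above swaps the unweighted `reg_loss` term for one scaled by a new `reg_strength` hyperparameter (default 0.1) before `tf.add_n` sums everything into `total_loss`. A minimal sketch of the effect, with invented loss values standing in for the real YNet tensors; only `reg_strength` and the `tf.add_n` pattern mirror the diff:

```python
# Sketch of the weighted YNet loss assembly (TF 1.x); the loss values are
# invented stand-ins for the real graph tensors.
import tensorflow as tf

inverter_loss = tf.constant(1.2)
decoder_loss_re = tf.constant(0.8)
decoder_loss_im = tf.constant(0.7)
reg_loss = tf.constant(5.0)   # unscaled regularization term

hyper_params = {'reg_strength': 0.1}
reg_str = hyper_params.get('reg_strength', 0.1)

# Scaling reg_loss keeps regularization from dominating the task losses:
# it contributes 0.5 here instead of 5.0.
losses = [inverter_loss, decoder_loss_re, decoder_loss_im, reg_str * reg_loss]
total_loss = tf.add_n(losses, name='total_loss')

with tf.Session() as sess:
    print(sess.run(total_loss))   # 1.2 + 0.8 + 0.7 + 0.5 = 3.2
```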
+10 −2
@@ -34,7 +34,7 @@ def decay_warmup(params, hyper_params, global_step):
     # Decay/ramp the learning rate exponentially based on the number of steps.
     def ramp():
         lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE, global_step, ramp_steps, LEARNING_RATE_DECAY_FACTOR,
-                                        staircase=False)
+                                        staircase=True)
         lr = INITIAL_LEARNING_RATE ** 2 * tf.pow(lr, tf.constant(-1.))
         lr = tf.minimum(lr,WARM_UP_LEARNING_RATE_MAX)
         return lr
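
For context, `ramp()` builds a warm-up out of an inverted exponential decay: `tf.train.exponential_decay` returns `INITIAL_LEARNING_RATE * factor**(step/ramp_steps)`, so dividing `INITIAL_LEARNING_RATE**2` by it yields exponential growth when the factor is below 1, and switching to `staircase=True` makes that growth jump at `ramp_steps` boundaries instead of changing continuously. A NumPy sketch with assumed constants:

```python
# NumPy sketch of ramp(); lr0, k, ramp_steps, lr_max are assumed values.
import numpy as np

lr0, k, ramp_steps, lr_max = 1e-4, 0.5, 100, 1e-2
for step in (0, 100, 200, 300):
    decayed = lr0 * k ** np.floor(step / ramp_steps)  # staircase=True behavior
    lr = min(lr0 ** 2 / decayed, lr_max)              # inverted and capped
    print(step, lr)   # 1e-04, 2e-04, 4e-04, 8e-04
```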
@@ -51,9 +51,17 @@ def decay_warmup(params, hyper_params, global_step):

         return lr

+    warm_up_slope = hyper_params.get('warm_up_slope', 1.)
+    def constant_ramp():
+        lr = tf.cast(INITIAL_LEARNING_RATE, tf.float32) * (tf.cast(global_step, tf.float32) * warm_up_slope  + 1)
+        lr = tf.math.minimum(tf.cast(WARM_UP_LEARNING_RATE_MAX, tf.float32), lr)
+        return lr
+
     if hyper_params['warm_up']:
         # LEARNING_RATE = tf.cond(global_step < ramp_up_steps, ramp, lambda: decay(ramp()))
-        LEARNING_RATE = tf.cond(global_step < ramp_up_steps, linear_ramp, lambda: decay(linear_ramp()))
+        #LEARNING_RATE = tf.cond(global_step < ramp_up_steps, linear_ramp, lambda: decay(linear_ramp()))
+        ramp_up_steps = tf.cast(WARM_UP_LEARNING_RATE_MAX/INITIAL_LEARNING_RATE, global_step.dtype)
+        LEARNING_RATE = tf.cond(global_step < ramp_up_steps, constant_ramp, lambda: decay(constant_ramp()))
     else:
         LEARNING_RATE = tf.train.exponential_decay(INITIAL_LEARNING_RATE, global_step, decay_steps,
                                         LEARNING_RATE_DECAY_FACTOR, staircase=True)
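
The new `constant_ramp` replaces the previous `linear_ramp` call: the learning rate grows linearly from `INITIAL_LEARNING_RATE` with a configurable `warm_up_slope`, capped at `WARM_UP_LEARNING_RATE_MAX`, and `ramp_up_steps` is now derived as the ratio of the cap to the initial rate, i.e. the step at which a slope-1 ramp reaches the cap. A NumPy sketch under assumed values:

```python
# NumPy sketch of the constant_ramp warm-up; lr0, lr_max, slope are assumed.
import numpy as np

lr0, lr_max, slope = 1e-4, 1e-2, 1.0   # INITIAL_LR, WARM_UP_LR_MAX, warm_up_slope
ramp_up_steps = lr_max / lr0           # 100: where a slope-1 ramp meets the cap
for step in (0, 50, 99, 150):
    lr = min(lr0 * (step * slope + 1), lr_max)   # linear growth, capped
    phase = 'ramp' if step < ramp_up_steps else 'decay(constant_ramp())'
    print(step, lr, phase)
```

One observable quirk: `ramp_up_steps` assumes a slope of 1, so with `warm_up_slope != 1.` the ramp hits `WARM_UP_LEARNING_RATE_MAX` before (or after) `tf.cond` hands control to the decay branch.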
+3 −2
@@ -1247,7 +1247,7 @@ class FCDenseNet(ConvNet):

     def _freq2space(self, inputs=None):
         shape = inputs.shape
-        weights_dim = 512
+        weights_dim = 256
         num_fc = 2
         # if weights_dim < 4096 :
         fully_connected = OrderedDict({'type': 'fully_connected','weights': weights_dim,'bias': weights_dim, 'activation': 'relu',
@@ -1266,7 +1266,8 @@ class FCDenseNet(ConvNet):
                 out = tf.reshape(out, [shape[0], -1])
             for i in range(num_fc):
                 if i > 0:
-                    weights_dim = min(4096, int(shape.as_list()[-2]**2))
+                    #weights_dim = min(4096, int(shape.as_list()[-2]**2))
+                    weights_dim = min(weights_dim, int(shape.as_list()[-2]**2))
                     fully_connected['weights'] = weights_dim
                     fully_connected['bias'] = weights_dim
                 with tf.variable_scope('FC_%d' %i, reuse=self.reuse) as _ :
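
Here `_freq2space` halves the first fully-connected width (512 to 256) and, for the later FC layers, caps the width at the configured `weights_dim` rather than the hard-coded 4096. A standalone sketch of the capping logic, with a hypothetical spatial size in place of `shape.as_list()[-2]`:

```python
# Sketch of the FC-width capping after the change; spatial is a hypothetical
# stand-in for shape.as_list()[-2].
from collections import OrderedDict

weights_dim, num_fc, spatial = 256, 2, 128
fully_connected = OrderedDict({'type': 'fully_connected',
                               'weights': weights_dim, 'bias': weights_dim})
for i in range(num_fc):
    if i > 0:
        # New cap: the configured width, not 4096. For spatial = 128 the old
        # line gave min(4096, 16384) = 4096; this now yields 256.
        weights_dim = min(weights_dim, spatial ** 2)
        fully_connected['weights'] = weights_dim
        fully_connected['bias'] = weights_dim
    print(i, fully_connected['weights'])   # 256, 256
```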
+11 −7
@@ -410,8 +410,8 @@ def train(network_config, hyper_params, params, gpu_id=None):
             val = validate(network_config, hyper_params, params, sess, dset, num_batches=50)
             val_results.append((train_elf.last_step,val))
         if doFinish:
-            val = validate(network_config, hyper_params, params, sess, dset, num_batches=50)
-            val_results.append((train_elf.last_step, val))
+            #val = validate(network_config, hyper_params, params, sess, dset, num_batches=50)
+            #val_results.append((train_elf.last_step, val))
             tf.reset_default_graph()
             tf.keras.backend.clear_session()
             sess.close()
@@ -518,9 +518,13 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
         print('not implemented')
     elif hyper_params['network_type'] == 'YNet':
         loss_params = hyper_params['loss_function']
-        model_output = tf.concat([n_net.model_output[subnet] for subnet in ['inverter', 'decoder_RE', 'decoder_IM']], axis=1)
+        #model_output = tf.concat([n_net.model_output[subnet] for subnet in ['inverter', 'decoder_RE', 'decoder_IM']], axis=1)
+        model_output = [n_net.model_output[subnet] for subnet in ['inverter', 'decoder_RE', 'decoder_IM']]
+        labels = [tf.expand_dims(itm, axis=1) for itm in tf.unstack(labels, axis=1)]
         if loss_params['type'] == 'MSE_PAIR':
-            errors = tf.losses.mean_pairwise_squared_error(tf.cast(labels, tf.float32), tf.cast(model_output, tf.float32))
+            errors = [tf.losses.mean_pairwise_squared_error(tf.cast(label, tf.float32), out)
+                                                for label, out in zip(labels, model_output)]
+            errors = tf.stack(errors)
             loss_label= loss_params['type']
         elif loss_params['type'] == 'ABS_DIFF':
             loss_label= 'ABS_DIFF'
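
The validation path now mirrors YNet's three-headed structure: instead of concatenating the `inverter`, `decoder_RE`, and `decoder_IM` outputs into one tensor, it keeps them as a list, splits `labels` to match, and computes one pairwise MSE per head, stacked into a length-3 tensor. A self-contained sketch with random stand-in tensors (the shapes are assumptions, not the repo's):

```python
# Sketch of the per-head validation error (TF 1.x); random tensors stand in
# for the real subnet outputs and labels, and the shapes are assumed.
import tensorflow as tf

batch, h, w = 4, 64, 64
model_output = [tf.random.uniform([batch, 1, h, w]) for _ in range(3)]
labels = [tf.random.uniform([batch, 1, h, w]) for _ in range(3)]

# One pairwise MSE per (label, output) pair; stacking keeps the inverter,
# decoder_RE, and decoder_IM errors separately reportable.
errors = tf.stack([tf.losses.mean_pairwise_squared_error(lab, out)
                   for lab, out in zip(labels, model_output)])

with tf.Session() as sess:
    print(sess.run(errors))   # shape (3,), one error per subnet
```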
@@ -535,8 +539,8 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
         elif num_batches > dset.num_samples:
             num_samples = dset.num_samples
         errors = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples//params['batch_size'])])
-        result = errors.mean()
-        print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, errors.mean()))
+        result = errors.mean(0)
+        print_rank('Validation Reconstruction Error %s: '% loss_label, result)
     elif hyper_params['network_type'] == 'inverter':
         loss_params = hyper_params['loss_function']
         if labels.shape.as_list()[1] > 1:
@@ -563,7 +567,7 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
             num_samples = dset.num_samples
         errors = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples//params['batch_size'])])
         result = errors.mean()
-        print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, errors.mean()))
+        print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, result))
         tf.summary.scalar("Validation_loss_label_%s" % loss_label, tf.constant(errors.mean()))
     return result
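
Finally, because each `sess.run` in the YNet branch now returns a length-3 error vector, the stacked `errors` array has shape `[num_batches, 3]`, and `errors.mean(0)` averages over batches while keeping one number per subnet; the old scalar `errors.mean()` (still used in the `inverter` branch) would blend the three heads together. A NumPy sketch with invented numbers:

```python
# NumPy sketch of the errors.mean(0) change; the values are invented.
import numpy as np

# Stacked per-batch results: num_batches rows x 3 per-subnet errors.
errors = np.array([[0.10, 0.20, 0.15],
                   [0.12, 0.18, 0.13]])

print(errors.mean())    # 0.1467 -- one scalar, blends the three heads
print(errors.mean(0))   # [0.11 0.19 0.14] -- one average per subnet
```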