Commit 6ada3103 authored by Laanait, Nouamane

adding freq2space module to fcdensenet, functions for training ynet, and minor changes to training loop
parent 8fa3f4b6
+3 −3
@@ -528,9 +528,9 @@ class DatasetLMDB(DatasetTFRecords):
                 for _ in range(self.params['batch_size']):
                     image, label = iterator.get_next()
                     image = tf.reshape(image, self.data_specs['image_shape'])
-                    if self.params[self.mode + '_distort']:
-                        image = self.add_noise_image(image)
-                    images.append(image)
+                    # if self.params[self.mode + '_distort']:
+                        # image = self.add_noise_image(image)
+                    images.append(tf.reshape(image, self.data_specs['image_shape']))
                     labels.append(tf.reshape(label, self.data_specs['label_shape']))
             if tf.executing_eagerly():
                 images = tf.stack(images)
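
The hunk above assembles a batch by appending one reshaped tensor per sample and stacking them when eager execution is on. A minimal standalone sketch of that pattern (not part of the diff; shapes and batch size are hypothetical):

    import tensorflow as tf

    image_shape = [1, 64, 64]                   # stand-in for data_specs['image_shape']
    images = []
    for _ in range(4):                          # stand-in for params['batch_size']
        image = tf.zeros([1, 64, 64])           # stand-in for iterator.get_next()
        images.append(tf.reshape(image, image_shape))
    images = tf.stack(images)                   # batch tensor of shape [4, 1, 64, 64]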
+36 −4
@@ -61,8 +61,12 @@ def calc_loss(n_net, scope, hyper_params, params, labels, images=None, summary=F
         #     labels = tf.transpose(labels, perm=[0, 2, 3, 1])
         #     labels = tf.image.resize(labels, n_net.model_output.shape.as_list()[-2:], method=tf.image.ResizeMethod.BILINEAR)
         #     labels = tf.transpose(labels, perm=[0, 3, 1, 2])
+        if labels_shape[1] > 1:
+            pot_labels, _, _ = [tf.expand_dims(itm, axis=1) for itm in tf.unstack(labels, axis=1)]
+        else:
+            pot_labels = labels
         weight=None
-        _ = calculate_loss_regressor(n_net.model_output, labels, params, hyper_params, weight=weight)
+        _ = calculate_loss_regressor(n_net.model_output, pot_labels, params, hyper_params, weight=weight)
     if hyper_params['network_type'] == 'fft_inverter':
         n_net.model_output = tf.exp(1.j * tf.cast(n_net.model_output, tf.complex64))
         psi_pos = tf.ones([1,16,512,512], dtype=tf.complex64)
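
The new pot_labels branch keeps only the first channel of a multi-channel label tensor for the regressor loss. A minimal sketch of the unstack/expand_dims split (not part of the diff; the [batch, 3, H, W] layout is an assumption):

    import tensorflow as tf

    labels = tf.zeros([2, 3, 64, 64])   # assumed [batch, channel, H, W]
    # Split along the channel axis, keep the first channel, restore the axis.
    pot_labels, _, _ = [tf.expand_dims(t, axis=1) for t in tf.unstack(labels, axis=1)]
    print(pot_labels.shape)             # (2, 1, 64, 64)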
@@ -101,10 +105,8 @@ def calc_loss(n_net, scope, hyper_params, params, labels, images=None, summary=F
     #Assemble all of the losses.
     losses = tf.get_collection(tf.GraphKeys.LOSSES)
     if hyper_params['network_type'] == 'YNet':
-        # losses = [inverter_loss, decoder_loss_re]
         losses = [inverter_loss , decoder_loss_re, decoder_loss_im ]
-        # losses = [inverter_loss , decoder_loss_im ]
-    # losses = [inverter_loss]
+        losses = ynet_adjusted_losses(losses, tf.train.get_or_create_global_step())
     regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
     # Calculate the total loss 
     total_loss = tf.add_n(losses + regularization, name='total_loss')
@@ -204,3 +206,33 @@ def calculate_loss_regressor(net_output, labels, params, hyper_params, weight=No
     if loss_params['type'] == 'LOG':
         cost = tf.losses.log_loss(labels, weights=weight, predictions=net_output, reduction=tf.losses.Reduction.MEAN)
     return cost
+
+
+def ynet_adjusted_losses(losses, global_step):
+    '''
+    Schedule the different loss components based on global training step
+    '''
+    threshold = 0.8
+    max_prefac = 20
+    ema = tf.train.ExponentialMovingAverage(0.99)
+    loss_averages_op = ema.apply(losses)
+
+    with tf.control_dependencies([loss_averages_op]):
+        def ramp():
+            prefac = tf.cast(tf.train.exponential_decay(tf.constant(1.), global_step, 2 * global_step, 
+                            tf.constant(0.5, dtype=tf.float32), staircase=False), tf.float32)
+            prefac = 1 ** 2 * tf.pow(prefac, tf.constant(-1., dtype=tf.float32))
+            prefac = tf.minimum(prefac, tf.cast(max_prefac, tf.float32))
+            return prefac
+
+        def decay(prefac_current):
+            prefac = tf.train.exponential_decay(prefac_current, global_step, 2 * global_step, tf.constant(0.5, dtype=tf.float32),
+                                            staircase=True)
+            return prefac
+    # inv_loss, dec_re_loss, dec_im_loss = [ema.average(tens.name) for tens in losses]
+    inv_loss, dec_re_loss, dec_im_loss = losses 
+
+    prefac  = tf.cond(inv_loss < threshold, ramp, lambda: decay(ramp()))
+    tf.summary.scalar("prefac_inverter", prefac)
+    losses = [prefac * (inv_loss - threshold), dec_re_loss, dec_im_loss]
+    return losses
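
A quick numeric check of the prefactor above (not part of the diff), assuming TF1 exponential_decay semantics, i.e. decayed = initial * decay_rate ** (step / decay_steps) when staircase=False. Because decay_steps is tied to 2 * global_step, the decayed value is 0.5 ** 0.5 at every positive step, so ramp() yields a constant well below max_prefac:

    step = 1000                                 # any positive global step
    max_prefac = 20
    decayed = 1.0 * 0.5 ** (step / (2 * step))  # == 0.5 ** 0.5 for all step > 0
    prefac = min(decayed ** -1.0, max_prefac)   # 1 ** 2 * decayed ** -1 == sqrt(2) ~ 1.414
    print(prefac)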
+151 −77
File changed; preview size limit exceeded, changes collapsed.

+90 −47
@@ -111,7 +111,6 @@ class TrainHelper(object):

     def log_stats(self, loss_value, learning_rate):
         self.nanloss(loss_value)
-        if hvd.rank() == 0:
         t = time.time( )
         duration = t - self.start_time
         examples_per_sec = self.params['batch_size'] * hvd.size() / duration
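
With the rank guard removed, log_stats now runs on every rank. A worked check of its throughput figure (not part of the diff; the numbers are assumed), counting one batch per rank per step:

    batch_size, num_ranks, step_duration = 32, 4, 0.5   # assumed values
    examples_per_sec = batch_size * num_ranks / step_duration
    print(examples_per_sec)                             # 256.0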
@@ -358,11 +357,11 @@ def train(network_config, hyper_params, params):
     #     train_elf = TrainHelper(params, saver, None, n_net.get_ops(), last_step=last_step)
 
     train_elf = TrainHelper(params, saver, summary_writer,  n_net.get_ops(), last_step=last_step, log_freq=params['log_frequency'])
-    if params['restart']:
-        saveStep = train_elf.last_step + params['save_step']
-        validateStep = train_elf.last_step + params['validate_step']
-        summaryStep = train_elf.last_step + params['summary_step'] 
-    else:
+    # if params['restart']:
+    #     saveStep = train_elf.last_step + params['save_step']
+    #     validateStep = train_elf.last_step + params['validate_step']
+    #     summaryStep = train_elf.last_step + params['summary_step'] 
+    # else:
     saveStep =  params['save_step']
     validateStep = params['validate_step']
     summaryStep = params['summary_step']
@@ -371,54 +370,54 @@ def train(network_config, hyper_params, params):
     maxSteps  = params[ 'max_steps' ]
     logFreq   = params[ 'log_frequency' ]
     traceStep = params[ 'trace_step' ]
+    maxTime = params.get('max_time', 1e12)
 
     while train_elf.last_step < maxSteps :
         train_elf.before_run()
 
-        doLog   = train_elf.last_step % logFreq  == 0
-        doSave  = train_elf.last_step >= saveStep 
-        doSumm  = train_elf.last_step > summaryStep 
-        doTrace = train_elf.last_step == traceStep and params['gpu_trace']
+        doLog   = bool(train_elf.last_step % logFreq  == 0)
+        doSave  = bool(train_elf.last_step % saveStep == 0)
+        doSumm  = bool(train_elf.last_step % summaryStep == 0 and params['debug'])
+        doTrace = bool(train_elf.last_step == traceStep and params['gpu_trace'])
+        doValidate = bool(train_elf.last_step % validateStep == 0)
+        doFinish = bool(train_elf.start_time - params['start_time'] > maxTime)
         if train_elf.last_step == 1 and params['debug']:
-            pass
+            # if hvd.rank() == 0:
             summary = sess.run([train_op,  summary_merged])[-1]
             train_elf.write_summaries( summary )
-        if not doLog and not doSave and not doTrace and not doSumm:
+        elif not doLog and not doSave and not doTrace and not doSumm:
             sess.run(train_op)
-        elif doLog and not doSave :
+        elif doLog and not doSave  and not doSumm:
             _, loss_value, lr = sess.run( [ train_op, total_loss, learning_rate ] )
             train_elf.log_stats( loss_value, lr )
-        elif doLog and doSave :
+        elif doLog and doSumm and doSave :
             _, summary, loss_value, lr = sess.run( [ train_op, summary_merged, total_loss, learning_rate ])
             train_elf.log_stats( loss_value, lr )
             train_elf.write_summaries( summary )
             if hvd.rank( ) == 0 :
                 saver.save(sess, checkpoint_file, global_step=train_elf.last_step)
                 print_rank('Saved Checkpoint.')
-            saveStep += params['save_step']
-        elif doSumm and params['debug']:
-            # if hvd.rank() == 0:
+        elif doLog and doSumm :
+            _, summary, loss_value, lr = sess.run( [ train_op, summary_merged, total_loss, learning_rate ])
+            train_elf.log_stats( loss_value, lr )
+            train_elf.write_summaries( summary )
+        elif doSumm:
             summary = sess.run([train_op,  summary_merged])[-1]
             train_elf.write_summaries( summary )
-            summaryStep += params['summary_step'] 
-        elif doSave :
-            #summary = sess.run([train_op,  summary_merged])[-1]
-            #train_elf.write_summaries( summary )
-            if hvd.rank( ) == 0 :
-                saver.save(sess, checkpoint_file, global_step=train_elf.last_step)
-                print_rank('Saved Checkpoint.')
-            saveStep += params['save_step']
         elif doTrace :
             sess.run(train_op, options=run_options, run_metadata=run_metadata)
             train_elf.save_trace(run_metadata, params[ 'trace_dir' ], params[ 'trace_step' ] )
             train_elf.before_run()
         # Here we do validation:
         #if train_elf.elapsed_epochs > next_validation_epoch:
-        if train_elf.last_step > validateStep:
+        if doValidate:
             validate(network_config, hyper_params, params, sess, dset)
-            validateStep += params['validate_step']
             #next_validation_epoch += params['epochs_per_validation']
+        if doFinish:
+            saver.save(sess, checkpoint_file, global_step=train_elf.last_step)
+            print_rank('Saved Final Checkpoint.')
+            return
 
 
 
 def train_inverter(network_config, hyper_params, params):
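
The rewritten loop replaces the accumulator-style triggers (saveStep += params['save_step'] after each save) with stateless modular tests against the current step. A minimal sketch of the two styles side by side (not part of the diff; values are hypothetical):

    step, save_interval = 500, 250

    # Old style: compare against a moving target, then advance it.
    saveStep = 250
    do_save_old = step >= saveStep
    if do_save_old:
        saveStep += save_interval

    # New style: derive the trigger from the step itself.
    do_save_new = bool(step % save_interval == 0)
    print(do_save_old, do_save_new)   # True True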
@@ -754,6 +753,9 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
         if params['network_class'] == 'fcnet':
             n_net = network.FCNet(scope, params, hyper_params, network_config, images, labels,
                                     operation='eval', summary=summary, verbose=True)
+        if params['network_class'] == 'YNet':
+            n_net = network.YNet(scope, params, hyper_params, network_config, images, labels,
+                                        operation='eval', summary=summary, verbose=True)
 
         # Build it and propagate images through it.
         n_net.build_model()
@@ -762,7 +764,7 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
     if hyper_params['network_type'] == 'regressor' or hyper_params['network_type'] == 'classifier':
         labels_shape = labels.get_shape().as_list()
         layer_params={'bias':labels_shape[-1], 'weights':labels_shape[-1],'regularize':False}
-        logits = fully_connected(n_net, layer_params, params['batch_size'],
+        logits = losses.fully_connected(n_net, layer_params, params['batch_size'],
                                 name='linear',reuse=None)
     else:
         pass
@@ -793,6 +795,8 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
         print('not implemented')
     elif hyper_params['network_type'] == 'inverter':
         loss_params = hyper_params['loss_function']
+        if labels.shape.as_list()[1] > 1:
+            labels, _, _ = [tf.expand_dims(itm, axis=1) for itm in tf.unstack(labels, axis=1)]
         if loss_params['type'] == 'MSE_PAIR':
             errors = tf.losses.mean_pairwise_squared_error(tf.cast(labels, tf.float32), tf.cast(n_net.model_output, tf.float32))
             loss_label= loss_params['type'] 
@@ -801,14 +805,19 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
             errors = tf.losses.absolute_difference(tf.cast(labels, tf.float32), tf.cast(n_net.model_output, tf.float32), reduction=tf.losses.Reduction.MEAN)
         errors = tf.expand_dims(errors,axis=0)
         error_averaging = hvd.allreduce(errors)
-        errors = np.array([sess.run([IO_ops,errors])[-1] for i in range(dset.num_samples // params['batch_size'])])
 
+        if num_batches is not None:
+            num_samples = num_batches
+        else:
+            num_samples = dset.num_samples
+        errors = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples)])
         # errors = np.array([sess.run([IO_ops,errors])[-1] for i in range(dset.num_samples)])
         # errors = tf.reduce_mean(errors)
         # avg_errors = hvd.allreduce(tf.expand_dims(errors, axis=0))
         # error = sess.run(avg_errors)
         # print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, errors.mean()))
         print('rank=%d, Validation Reconstruction Error %s: %3.3e' % (hvd.rank(), loss_label, errors.mean()))
-        tf.summary.scalar("Validation_loss_label_%s" %hvd.rank() , tf.constant(errors.mean()))
+        print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, errors.mean()))
+        tf.summary.scalar("Validation_loss_label_%s" % loss_label, tf.constant(errors.mean()))
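
In the validation change above, each sess.run call yields one batch error already averaged across ranks by hvd.allreduce; the reported figure is the mean over num_samples such batches. A minimal sketch of that final reduction (not part of the diff; the per-batch errors are stand-ins):

    import numpy as np

    batch_errors = np.array([1.2e-2, 1.0e-2, 1.1e-2])  # stand-ins for sess.run results
    print('Validation Reconstruction Error MSE_PAIR: %3.3e' % batch_errors.mean())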


@@ -869,6 +878,9 @@ def validate_ckpt(network_config, hyper_params, params, num_batches=None,
             if params['network_class'] == 'fcnet':
                 n_net = network.FCNet(scope, params, hyper_params, network_config, images, labels,
                                         operation='eval_ckpt', summary=False, verbose=True)
+            if params['network_class'] == 'YNet':
+                n_net = network.YNet(scope, params, hyper_params, network_config, images, labels,
+                                        operation='eval_ckpt', summary=False, verbose=True)
             # Build it and propagate images through it.
             n_net.build_model()
 
@@ -969,6 +981,37 @@ def validate_ckpt(network_config, hyper_params, params, num_batches=None,
                         num_samples = dset.num_samples
                     error = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples)])
                     print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, error.mean()))
+            elif hyper_params['network_type'] == 'YNet':
+                loss_params = hyper_params['loss_function']
+                model_output = tf.stack([n_net.model_output[subnet] for subnet in ['inverter', 'decoder_RE', 'decoder_IM']], axis=0)
+                if params['output']:
+                    output = tf.cast(model_output, tf.float32)
+                    print('output shape',output.get_shape().as_list()) 
+                    output_dir = os.path.join(os.getcwd(),'outputs')
+                    if num_batches is not None:
+                        num_samples = num_batches
+                    else:
+                        num_samples = dset.num_samples
+                    for idx in range(num_samples):
+                        output_arr, label_arr = sess.run([IO_ops, model_output, labels])[-2:]
+                        #label_arr = sess.run([IO_ops, labels])[-1]
+                        np.save(os.path.join(output_dir,'label_%d_%d_%s.npy' % (idx, hvd.rank(), format(last_step))), label_arr)
+                        np.save(os.path.join(output_dir,'output_%d_%d_%s.npy' % (idx, hvd.rank(), format(last_step))), output_arr)
+                else:
+                    if loss_params['type'] == 'MSE_PAIR':
+                        errors = tf.losses.mean_pairwise_squared_error(tf.cast(labels, tf.float32), tf.cast(model_output, tf.float32))
+                        loss_label= loss_params['type'] 
+                    else: 
+                        loss_label= 'ABS_DIFF'
+                        errors = tf.losses.absolute_difference(tf.cast(labels, tf.float32), tf.cast(model_output, tf.float32), reduction=tf.losses.Reduction.MEAN)
+                    errors = tf.expand_dims(errors,axis=0)
+                    error_averaging = hvd.allreduce(errors)
+                    if num_batches is not None:
+                        num_samples = num_batches
+                    else:
+                        num_samples = dset.num_samples
+                    error = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples)])
+                    print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, error.mean()))
             if sleep < 0:
                 break
             else: