Commit 4034f783 authored by Nouamane Laanait

minor bugs


Former-commit-id: 03974ced
parent ff3471f2
+3 −3
@@ -13,9 +13,9 @@ from tensorflow.python.ops import data_flow_ops
import horovod.tensorflow as hvd
import lmdb
import time
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as dali_ops
import nvidia.dali.plugin.tf as dali_tf
#from nvidia.dali.pipeline import Pipeline
#import nvidia.dali.ops as dali_ops
#import nvidia.dali.plugin.tf as dali_tf

tf.logging.set_verbosity(tf.logging.ERROR)
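
The hunk above drops the hard NVIDIA DALI dependency by commenting its imports out. If DALI is meant to stay optional, a guarded import is one common pattern; a minimal sketch (the HAVE_DALI flag is a hypothetical name, not part of the repo):

try:
    from nvidia.dali.pipeline import Pipeline
    import nvidia.dali.ops as dali_ops
    import nvidia.dali.plugin.tf as dali_tf
    HAVE_DALI = True   # pipeline code can branch on this flag
except ImportError:
    HAVE_DALI = False  # fall back to the plain TensorFlow input path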

+12 −12
@@ -5,14 +5,14 @@ email: laanaitn@ornl.gov
"""
from collections import OrderedDict
import json
import horovod.tensorflow as hvd

# JSON utility functions

# import horovod.tensorflow as hvd

# def print(self, *args, **kwargs):
#    if hvd.rank() == 0 :
#        print(*args, **kwargs)
def print_rank(*args, **kwargs):
    if hvd.rank() == 0:
        print(*args, **kwargs)

def write_json_network_config(file, layer_keys, layer_params):
    """
@@ -26,7 +26,7 @@ def write_json_network_config(file, layer_keys, layer_params):
    network_config = OrderedDict(zip(layer_keys, layer_params))
    with open(file, mode='w') as f:
        json.dump(network_config, f, indent=4)
    print('Wrote %d NN layers to %s' % (len(network_config.keys()), file))
    print_rank('Wrote %d NN layers to %s' % (len(network_config.keys()), file))


def load_json_network_config(file):
@@ -46,7 +46,7 @@ def load_json_network_config(file):
        output = json.load(f, object_hook=_as_ordered_dict, object_pairs_hook=_as_ordered_dict)
        network_config = OrderedDict(output)

    print('Read %d NN layers from %s' % (len(network_config.keys()), file))
    print_rank('Read %d NN layers from %s' % (len(network_config.keys()), file))
    return network_config


@@ -60,7 +60,7 @@ def write_json_hyper_params(file, hyper_params):

    with open(file, mode='w') as f:
        json.dump(hyper_params, f, indent=4)
    print('Wrote %d hyperparameters to %s' % (len(hyper_params.keys()), file))
    print_rank('Wrote %d hyperparameters to %s' % (len(hyper_params.keys()), file))


def load_json_hyper_params(file):
@@ -72,7 +72,7 @@ def load_json_hyper_params(file):
    with open(file, mode='r') as f:
        hyper_params = json.load(f)

    print('Read %d hyperparameters from %s' % (len(hyper_params.keys()), file))
    print_rank('Read %d hyperparameters from %s' % (len(hyper_params.keys()), file))
    return hyper_params


@@ -81,7 +81,7 @@ def load_flags_from_simple_json(file_path, flags, verbose=False):
    for parm_name in image_parms.keys():
        val = image_parms[parm_name]
        if verbose:
            print('\t{}: {}'.format(parm_name, val))
            print_rank('\t{}: {}'.format(parm_name, val))
        if isinstance(val, bool):
            dtype = 'boolean'
            func = flags.DEFINE_boolean
@@ -97,7 +97,7 @@ def load_flags_from_simple_json(file_path, flags, verbose=False):
        else:
            raise NotImplementedError('{} : {} of type that we cannot handle now'.format(parm_name, val))
        if verbose:
            print('{} : {} saved as {}'.format(parm_name, val, dtype))
            print_rank('{} : {} saved as {}'.format(parm_name, val, dtype))
        func(parm_name, val, '')  # flag defined with an empty help string


@@ -114,7 +114,7 @@ def load_flags_from_json(file_path, flags, verbose=False):
    image_parms = load_json_hyper_params(file_path)
    for parm_name, parm_values in list(image_parms.items()):
        if verbose:
            print('\t{}: {}'.format(parm_name, parm_values))
            print_rank('\t{}: {}'.format(parm_name, parm_values))
        if parm_values['type'] == 'bool':
            func = flags.DEFINE_boolean
        elif parm_values['type'] == 'int':
@@ -126,6 +126,6 @@ def load_flags_from_json(file_path, flags, verbose=False):
        else:
            raise NotImplementedError('Cannot handle type: {} for parameter: {}'.format(parm_values['type'], parm_name))
        if verbose:
            print('{} : {} saved as {} with description: {}'.format(parm_name, parm_values['value'],
            print_rank('{} : {} saved as {} with description: {}'.format(parm_name, parm_values['value'],
                                                                    parm_values['type'], parm_values['desc']))
        func(parm_name, parm_values['value'], parm_values['desc'])
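
All logging in this module now flows through print_rank, so a multi-worker Horovod job prints each message once instead of once per rank. A minimal usage sketch, assuming hvd.init() has already been called (the training entry point normally does this):

import horovod.tensorflow as hvd

hvd.init()
print_rank('starting validation')  # emitted by rank 0 only;
                                   # every other rank stays silent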
+9 −1
@@ -191,7 +191,7 @@ def calculate_loss_regressor(net_output, labels, params, hyper_params, weight=No
        global_step = 1
    loss_params = hyper_params['loss_function']
    assert loss_params['type'] == 'Huber' or loss_params['type'] == 'MSE' \
    or loss_params['type'] == 'LOG' or loss_params['type'] == 'MSE_PAIR' or loss_params['type'] == 'ABS_DIFF' or loss_params['type'] == 'ABS_DIFF_SCALED', "Type of regression loss function must be 'Huber' or 'MSE'"
    or loss_params['type'] == 'LOG' or loss_params['type'] == 'MSE_PAIR' or loss_params['type'] == 'ABS_DIFF' or loss_params['type'] == 'ABS_DIFF_SCALED' or loss_params['type'] == 'rMSE', "Type of regression loss function must be one of: 'Huber', 'MSE', 'LOG', 'MSE_PAIR', 'ABS_DIFF', 'ABS_DIFF_SCALED', 'rMSE'"
    if loss_params['type'] == 'Huber':
        # decay the residual cutoff exponentially
        decay_steps = int(params['NUM_EXAMPLES_PER_EPOCH'] / params['batch_size'] \
@@ -220,6 +220,14 @@ def calculate_loss_regressor(net_output, labels, params, hyper_params, weight=No
                                            #reduction=tf.losses.Reduction.SUM)
    if loss_params['type'] == 'MSE_PAIR':
        cost = tf.losses.mean_pairwise_squared_error(labels, net_output, weights=weight)
    if loss_params['type'] == 'rMSE':
        labels = tf.cast(labels, tf.float32)
        l2_true = tf.sqrt(tf.reduce_sum(labels ** 2, axis=[1, 2, 3]))
        l2_output = tf.sqrt(tf.reduce_sum(net_output ** 2, axis=[1, 2, 3]))
        cost = tf.reduce_mean(tf.abs(l2_true - l2_output)/l2_true)
        #cost = tf.reduce_mean(tf.sqrt(tf.reduce_sum((labels - net_output)**2, axis=[1,2,3]))/tf.sqrt(tf.reduce_sum(labels**2, axis=[1,2,3])))
        cost *= 100
        tf.add_to_collection(tf.GraphKeys.LOSSES, cost)
    if loss_params['type'] == 'LOG':
        cost = tf.losses.log_loss(labels, weights=weight, predictions=net_output, reduction=tf.losses.Reduction.MEAN)
    return cost
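
The new 'rMSE' branch scores a prediction by the relative gap between per-sample L2 norms, reported as a percentage, rather than by an elementwise difference; note that two tensors with equal norm but different content score zero under this metric. A NumPy sketch of the same quantity, assuming NCHW-shaped float arrays:

import numpy as np

def relative_l2_error(labels, outputs):
    # Per-sample L2 norms over the non-batch axes (1, 2, 3), then the mean
    # relative gap between them, scaled to a percentage -- mirroring the
    # TensorFlow branch above.
    l2_true = np.sqrt(np.sum(labels ** 2, axis=(1, 2, 3)))
    l2_out = np.sqrt(np.sum(outputs ** 2, axis=(1, 2, 3)))
    return 100.0 * np.mean(np.abs(l2_true - l2_out) / l2_true)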
+2 −1
@@ -2233,7 +2233,8 @@ class YNet(FCDenseNet, FCNet):
        self.network = dict([(key, itm) for key,itm in self.network.items()])

    def _batch_norm(self, input=None):
        out = tf.keras.layers.BatchNormalization(axis=1)(inputs=input, training= self.operation == 'train')
        #out = tf.keras.layers.BatchNormalization(axis=1)(inputs=input, training= self.operation == 'train')
        out = super(YNet, self)._batch_norm(input=input)
        return out

    def get_all_ops(self, subnet=None):
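
The override above stops constructing a fresh tf.keras.layers.BatchNormalization on every call (each instantiation creates its own new parameters) and defers to the base-class implementation instead. With multiple inheritance, which base supplies _batch_norm is decided by Python's method resolution order; a quick check, assuming the repo's class definitions are importable:

print(YNet.__mro__)
# -> (YNet, FCDenseNet, FCNet, ..., object): super(YNet, self)._batch_norm
#    resolves to FCDenseNet._batch_norm first if both bases define one.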
+42 −10
@@ -19,7 +19,7 @@ import tensorflow as tf
from collections import OrderedDict
import horovod.tensorflow as hvd
from tensorflow.python.client import timeline
from tensorflow.contrib.compiler import xla
#from tensorflow.contrib.compiler import xla

# stemdl
from . import network
@@ -152,8 +152,8 @@ def train(network_config, hyper_params, params):
    config.gpu_options.force_gpu_compatible = True
    config.intra_op_parallelism_threads = 6 
    config.inter_op_parallelism_threads = max(1, cpu_count()//6)
    config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
    #config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    #jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
    # JIT causes gcc errors on dgx-dl and is built without on Summit.
    sess = tf.Session(config=config)

@@ -786,6 +786,24 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
    elif hyper_params['network_type'] == 'hybrid':
        #TODO: implement evaluation call for hybrid network
        print('not implemented')
    elif hyper_params['network_type'] == 'YNet':
        loss_params = hyper_params['loss_function']
        model_output = tf.concat([n_net.model_output[subnet] for subnet in ['inverter', 'decoder_RE', 'decoder_IM']], axis=1)
        if loss_params['type'] == 'MSE_PAIR':
            errors = tf.losses.mean_pairwise_squared_error(tf.cast(labels, tf.float32), tf.cast(model_output, tf.float32))
            loss_label= loss_params['type'] 
        else: 
            loss_label= 'ABS_DIFF'
            errors = tf.losses.absolute_difference(tf.cast(labels, tf.float32), tf.cast(model_output, tf.float32), reduction=tf.losses.Reduction.MEAN)
        errors = tf.expand_dims(errors,axis=0)
        error_averaging = hvd.allreduce(errors)
        if num_batches is not None:
            num_samples = num_batches
        else:
            num_samples = dset.num_samples
        #error = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(4)])
        error = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples//params['batch_size'])])
        print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, error.mean()))
    elif hyper_params['network_type'] == 'inverter':
        loss_params = hyper_params['loss_function']
        if labels.shape.as_list()[1] > 1:
@@ -793,6 +811,13 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
        if loss_params['type'] == 'MSE_PAIR':
            errors = tf.losses.mean_pairwise_squared_error(tf.cast(labels, tf.float32), tf.cast(n_net.model_output, tf.float32))
            loss_label= loss_params['type'] 
        elif loss_params['type'] == 'rMSE':
            labels = tf.cast(labels, tf.float32)
            l2_true = tf.sqrt(tf.reduce_sum(labels ** 2, axis=[1, 2, 3]))
            l2_output = tf.sqrt(tf.reduce_sum(n_net.model_output ** 2, axis=[1, 2, 3]))
            errors = tf.reduce_mean(tf.abs(l2_true - l2_output)/l2_true)
            errors *= 100
            loss_label= loss_params['type'] 
        else: 
            loss_label= 'ABS_DIFF'
            errors = tf.losses.absolute_difference(tf.cast(labels, tf.float32), tf.cast(n_net.model_output, tf.float32), reduction=tf.losses.Reduction.MEAN)
@@ -803,7 +828,7 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
            num_samples = num_batches
        else:
            num_samples = dset.num_samples
        errors = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples)])
        errors = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples//params['batch_size'])])
        # errors = np.array([sess.run([IO_ops,errors])[-1] for i in range(dset.num_samples)])
        # errors = tf.reduce_mean(errors)
        # avg_errors = hvd.allreduce(tf.expand_dims(errors, axis=0))
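
Two fixes land in this validation path: the per-batch error is averaged across workers via hvd.allreduce (which returns the mean over ranks by default), and the loop count is corrected from one session run per sample to one per batch. A quick arithmetic check with hypothetical sizes:

num_samples = 2048                 # hypothetical validation-set size
batch_size = 64                    # hypothetical per-worker batch size
steps = num_samples // batch_size  # 32 runs, one per batch -- the old code
                                   # would have looped 2048 times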
@@ -917,6 +942,12 @@ def validate_ckpt(network_config, hyper_params, params, num_batches=None,
            ckpt_paths = [ckpt_paths[-1]]
            model_steps = [model_steps[-1]]

        if params['output']:
            output_dir = os.path.join(os.getcwd(), 'outputs_%s' % params['checkpt_dir'].split('/')[-1])
            if not os.path.exists(output_dir):
                tf.gfile.MakeDirs(output_dir)
                

        # Validate Models
        for ckpt, last_step in zip(ckpt_paths, model_steps):
            #
@@ -950,7 +981,6 @@ def validate_ckpt(network_config, hyper_params, params, num_batches=None,
                if params['output']:
                    output = tf.cast(n_net.model_output, tf.float32)
                    print('output shape',output.get_shape().as_list()) 
                    output_dir = os.path.join(os.getcwd(),'outputs')
                    if num_batches is not None:
                        num_samples = num_batches
                    else:
@@ -977,11 +1007,10 @@ def validate_ckpt(network_config, hyper_params, params, num_batches=None,
                    print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, error.mean()))
            elif hyper_params['network_type'] == 'YNet':
                loss_params = hyper_params['loss_function']
                model_output = tf.stack([n_net.model_output[subnet] for subnet in ['inverter', 'decoder_RE', 'decoder_IM']], axis=0)
                model_output = tf.concat([n_net.model_output[subnet] for subnet in ['inverter', 'decoder_RE', 'decoder_IM']], axis=1)
                if params['output']:
                    output = tf.cast(model_output, tf.float32)
                    print('output shape',output.get_shape().as_list()) 
                    output_dir = os.path.join(os.getcwd(),'outputs')
                    if num_batches is not None:
                        num_samples = num_batches
                    else:
@@ -998,14 +1027,17 @@ def validate_ckpt(network_config, hyper_params, params, num_batches=None,
                    else: 
                        loss_label= 'ABS_DIFF'
                        errors = tf.losses.absolute_difference(tf.cast(labels, tf.float32), tf.cast(model_output, tf.float32), reduction=tf.losses.Reduction.MEAN)
                    errors = tf.expand_dims(errors,axis=0)
                    error_averaging = hvd.allreduce(errors)
                    #errors = tf.expand_dims(errors,axis=0)
                    #error_averaging = hvd.allreduce(errors)
                    error_averaging = errors
                    if num_batches is not None:
                        num_samples = num_batches
                    else:
                        num_samples = dset.num_samples
                    #error = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(4)])
                    error = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples)])
                    print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, error.mean()))
                    print('Rank=%d, Validation Reconstruction Error %s: %3.3e' % (hvd.rank(),loss_label, error.mean()))
                    #print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, error.mean()))
            if sleep < 0:
                break
            else: