Commit 6d8ddd4d authored by Laanait, Nouamane's avatar Laanait, Nouamane
Browse files

new training function and class for YNet

parent 10d37686
Loading
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -118,6 +118,8 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
    total_loss = tf.cast(total_loss, tf.float32)
    # Generate summaries for the losses and get corresponding op
    loss_averages_op = _add_loss_summaries(total_loss, losses, summaries=summary)
    if hyper_params['network_type'] == 'YNet':
        return total_loss, loss_averages_op, losses 
    return total_loss, loss_averages_op

def fully_connected(n_net, layer_params, batch_size, wd=0, name=None, reuse=None):
+292 −1
Original line number Diff line number Diff line
@@ -48,7 +48,7 @@ def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
        variable = tf.cast(variable, dtype)
    return variable

class TrainHelper(object):
class TrainHelper:
    def __init__(self, params, saver, writer, net_ops, last_step=0, log_freq=1):
        self.params = params
        self.last_step = last_step
@@ -127,6 +127,22 @@ class TrainHelper(object):
            print_rank('loss is nan...')
            # sys.exit(0)

class TrainHelper_YNet(TrainHelper):
    def log_stats(self, loss_value, aux_losses, learning_rate):
        t = time.time( )
        duration = t - self.start_time
        examples_per_sec = self.params['batch_size'] * hvd.size() / duration
        self.cumm_time = (time.time() - self.cumm_time)/self.log_freq
        flops = self.net_ops * examples_per_sec
        avg_flops = self.net_ops * self.params['batch_size'] * hvd.size() / self.cumm_time
        loss_inv, loss_dec_re, loss_dec_im, loss_reg = aux_losses
        self.nanloss(loss_value) 
        format_str = (
        'time= %.1f, step= %2.2e, epoch= %2.2e, lr= %.2e, loss=%.3e, loss_inv= %.2e, loss_dec_im=%.2e, loss_dec_re=%.2e, loss_reg=%.2e, step_time= %2.2f sec, ranks= %d, examples/sec= %.1f')
        print_rank(format_str % ( t - self.params[ 'start_time' ],  self.last_step, self.elapsed_epochs,
                    learning_rate, loss_value, loss_inv, loss_dec_im, loss_dec_re, loss_reg, duration, hvd.size(), examples_per_sec))
        self.cumm_time = time.time()

def print_rank(*args, **kwargs):
    if hvd.rank() == 0:
        print(*args, **kwargs)
@@ -424,6 +440,281 @@ def train(network_config, hyper_params, params, gpu_id=None):
    sess.close()
    return val_results, loss_results

def train_YNet(network_config, hyper_params, params, gpu_id=None):
    """
    Train the network for a number of steps using horovod and asynchronous I/O staging ops.

    :param network_config: OrderedDict, network configuration
    :param hyper_params: OrderedDict, hyper_parameters
    :param params: dict
    :return: None
    """
    #########################
    # Start Session         #
    #########################
    # Config file for tf.Session()
    config = tf.ConfigProto(allow_soft_placement=params['allow_soft_placement'],
                           log_device_placement=params['log_device_placement'],
                           )
    config.gpu_options.allow_growth = True
    if gpu_id is None:
        gpu_id = hvd.local_rank() 
    config.gpu_options.visible_device_list = str(gpu_id)
    config.gpu_options.force_gpu_compatible = True
    config.intra_op_parallelism_threads = 6 
    config.inter_op_parallelism_threads = max(1, cpu_count()//6)
    #config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    #jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
    # JIT causes gcc errors on dgx-dl and is built without on Summit.
    sess = tf.Session(config=config)


    ############################
    # Setting up Checkpointing #
    ###########################

    last_step = 0
    if params[ 'restart' ] :
       # Check if training is a restart from checkpoint
       ckpt = tf.train.get_checkpoint_state(params[ 'checkpt_dir' ] )
       if ckpt is None :
          print_rank( '<ERROR> Could not restart from checkpoint %s' % params[ 'checkpt_dir' ])
       else :
          last_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
          print_rank("Restoring from previous checkpoint @ step=%d" %last_step)

    global_step = tf.Variable(last_step, name='global_step',trainable=False)


    ############################################
    # Setup Graph, Input pipeline and optimizer#
    ############################################
    # Start building the graph

    # Setup data stream
    with tf.device(params['CPU_ID']):
        with tf.name_scope('Input') as _:
            if params['filetype'] == 'tfrecord':
                dset = inputs.DatasetTFRecords(params, dataset=params['dataset'], debug=False)
            elif params['filetype'] == 'lmdb':
                dset = inputs.DatasetLMDB(params, dataset=params['dataset'], debug=params['debug'])
            images, labels = dset.minibatch()
            # Staging images on host
            staging_op, (images, labels) = dset.stage([images, labels])

    with tf.device('/gpu:%d' % hvd.local_rank()):
        # Copy images from host to device
        gpucopy_op, (images, labels) = dset.stage([images, labels])
        IO_ops = [staging_op, gpucopy_op]

        ##################
        # Building Model#
        ##################

        # Build model, forward propagate, and calculate loss
        scope = 'model'
        summary = False
        if params['debug']: 
            summary = True
        print_rank('Starting up queue of images+labels: %s,  %s ' % (format(images.get_shape()),
                                                                format(labels.get_shape())))

        with tf.variable_scope(scope,
                # Force all variables to be stored as float32
                custom_getter=float32_variable_storage_getter) as _:

            # Setup Neural Net
            n_net = network.YNet(scope, params, hyper_params, network_config, images, labels,
                                    operation='train', summary=summary, verbose=True)
            
                
            ###### XLA compilation #########    
            #if params['network_class'] == 'fcdensenet':
            #    def wrap_n_net(*args):
            #        images, labels = args
            #        n_net = network.FCDenseNet(scope, params, hyper_params, network_config, images, labels,
            #                            operation='train', summary=False, verbose=True)
            #        n_net.build_model()
            #        return n_net.model_output
            #
            #    n_net.model_output = xla.compile(wrap_n_net, inputs=[images, labels])
            ##############################

            # Build it and propagate images through it.
            n_net.build_model()

            # calculate the total loss
            total_loss, _, indv_losses = losses.calc_loss(n_net, scope, hyper_params, params, labels, step=global_step, images=images, summary=summary)

            #get summaries, except for the one produced by string_input_producer
            if summary: summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
            # print_rank([scope.name for scope in n_net.scopes])

        #######################################
        # Apply Gradients and setup train op #
        #######################################

        # get learning policy
        def learning_policy_func(step):
            return lr_policies.decay_warmup(params, hyper_params, step)
            ## TODO: implement other policies in lr_policies

        iter_size = params.get('accumulate_step', 0)
        skip_update_cond = tf.cast(tf.floormod(global_step, tf.constant(iter_size, dtype=tf.int32)), tf.bool)

        if params['IMAGE_FP16']:
            opt_type='mixed'
        else:
            opt_type=tf.float32
        # setup optimizer
        opt_dict = hyper_params['optimization']['params'] 
        train_opt, learning_rate = optimizers.optimize_loss(total_loss, hyper_params['optimization']['name'], 
                                opt_dict, learning_policy_func, run_params=params, hyper_params=hyper_params, iter_size=iter_size, dtype=opt_type, 
                                loss_scaling=hyper_params.get('loss_scaling',1.0), 
                                skip_update_cond=skip_update_cond,
                                on_horovod=True, model_scopes=n_net.scopes)  

    # Gather all training related ops into a single one.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    increment_op = tf.assign_add(global_step, 1)
    ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step)
    all_ops = tf.group(*([train_opt] + update_ops + IO_ops + [increment_op]))

    with tf.control_dependencies([all_ops]):
            train_op = ema.apply(tf.trainable_variables()) 
            # train_op = tf.no_op(name='train')

    ########################
    # Setting up Summaries #
    ########################

    # Stats and summaries
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    # if hvd.rank() == 0:
    summary_writer = tf.summary.FileWriter(os.path.join(params['checkpt_dir'], str(hvd.rank())), sess.graph)
        # Add Summary histograms for trainable variables and their gradients
    if params['debug']:
        predic_inverter = tf.transpose(n_net.model_output['inverter'], perm=[0,2,3,1])
        tf.summary.image("output_inverter", predic_inverter, max_outputs=2) 
        predic_decoder_RE = tf.transpose(n_net.model_output['decoder_RE'], perm=[0,2,3,1])
        predic_decoder_IM = tf.transpose(n_net.model_output['decoder_IM'], perm=[0,2,3,1])
        tf.summary.image("output_decoder_RE", predic_decoder_RE, max_outputs=2)
        tf.summary.image("output_decoder_IM", predic_decoder_IM, max_outputs=2)
        new_labels = tf.unstack(labels, axis=1)
        for label, tag in zip(new_labels, ['potential', 'probe_RE', 'probe_IM']):
            label = tf.expand_dims(label, axis=-1)
            # label = tf.transpose(label, perm=[0,2,3,1])  
            tf.summary.image(tag, label, max_outputs=2)
        tf.summary.image("inputs", tf.transpose(tf.reduce_mean(images, axis=1, keepdims=True), perm=[0,2,3,1]), max_outputs=4) 
          
    summary_merged = tf.summary.merge_all()

     ###############################
    # Setting up training session #
    ###############################

    #Initialize variables
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # Sync
    print_rank('Syncing horovod ranks...')
    sync_op = hvd.broadcast_global_variables(0)
    sess.run(sync_op)
    
    # prefill pipeline first
    print_rank('Prefilling I/O pipeline...')
    for i in range(len(IO_ops)):
        sess.run(IO_ops[:i + 1])


    # Saver and Checkpoint restore
    checkpoint_file = os.path.join(params[ 'checkpt_dir' ], 'model.ckpt')
    saver = tf.train.Saver(max_to_keep=None, save_relative_paths=True)

    # Check if training is a restart from checkpoint
    if params['restart'] and ckpt is not None:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print_rank("Restoring from previous checkpoint @ step=%d" % last_step)

    # Train
    train_elf = TrainHelper_YNet(params, saver, summary_writer,  n_net.get_ops(), last_step=last_step, log_freq=params['log_frequency'])
    saveStep =  params['save_step']
    validateStep = params['validate_step']
    summaryStep = params['summary_step']

    train_elf.run_summary()
    maxSteps  = params[ 'max_steps' ]
    logFreq   = params[ 'log_frequency' ]
    traceStep = params[ 'trace_step' ]
    maxTime = params.get('max_time', 1e12)
    
    val_results = []
    loss_results = []
    loss_value = 1e10
    val = 1e10
    while train_elf.last_step < maxSteps :
        train_elf.before_run()
        doLog   = bool(train_elf.last_step % logFreq  == 0)
        doSave  = bool(train_elf.last_step % saveStep == 0)
        doSumm  = bool(train_elf.last_step % summaryStep == 0 and params['debug'])
        doTrace = bool(train_elf.last_step == traceStep and params['gpu_trace'])
        doValidate = bool(train_elf.last_step % validateStep == 0)
        doFinish = bool(train_elf.start_time - params['start_time'] > maxTime)
        if train_elf.last_step == 1 and params['debug']:
            summary = sess.run([train_op,  summary_merged])[-1]
            train_elf.write_summaries( summary )
        elif not doLog and not doSave and not doTrace and not doSumm:
            sess.run(train_op)
        elif doLog and not doSave  and not doSumm:
            _, lr, loss_value, aux_losses = sess.run( [ train_op, learning_rate, total_loss, indv_losses])
            loss_results.append((train_elf.last_step, loss_value))
            train_elf.log_stats( loss_value, aux_losses, lr)
        elif doLog and doSumm and doSave :
            _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses,
                                                             learning_rate ])
            loss_results.append((train_elf.last_step, loss_value))
            train_elf.log_stats( loss_value, aux_losses, lr )
            train_elf.write_summaries( summary )
            if hvd.rank( ) == 0 :
                saver.save(sess, checkpoint_file, global_step=train_elf.last_step)
                print_rank('Saved Checkpoint.')
        elif doLog and doSumm :
            _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate ])
            loss_results.append((train_elf.last_step, loss_value))
            train_elf.log_stats( loss_value, aux_losses, lr )
            train_elf.write_summaries( summary )
        elif doSumm:
            summary = sess.run([train_op,  summary_merged])[-1]
            train_elf.write_summaries( summary )
        elif doSave :
            if hvd.rank( ) == 0 :
                saver.save(sess, checkpoint_file, global_step=train_elf.last_step)
                print_rank('Saved Checkpoint.')
        elif doTrace :
            sess.run(train_op, options=run_options, run_metadata=run_metadata)
            train_elf.save_trace(run_metadata, params[ 'trace_dir' ], params[ 'trace_step' ] )
            train_elf.before_run()
        # Here we do validation:
        if doValidate:
            val = validate(network_config, hyper_params, params, sess, dset, num_batches=50)
            val_results.append((train_elf.last_step,val))
        if doFinish: 
            #val = validate(network_config, hyper_params, params, sess, dset, num_batches=50)
            #val_results.append((train_elf.last_step, val))
            tf.reset_default_graph()
            tf.keras.backend.clear_session()
            sess.close()
            return val_results, loss_results
        if np.isnan(loss_value):
            break
    val_results.append((train_elf.last_step,val))
    tf.reset_default_graph()
    tf.keras.backend.clear_session()
    sess.close()
    return val_results, loss_results

def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
    """
    Runs validation with current weights