Commit 085b6660 authored by Laanait, Nouamane

Modifications to runtime.train to facilitate distributed hyperparameter search

parent b3bb25af
+34 −303
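In outline: train() gains an optional gpu_id argument so a search driver can pin each trial to a specific device, nanloss() no longer kills the process when the loss goes NaN, and train() and validate() now return their metrics (lists of (step, value) pairs and a validation result) instead of only logging them. A minimal sketch of a sweep driver built on those changes; the sweep() helper, the import path, and the 'learning_rate' key are assumptions for illustration, not part of this commit:

# Hypothetical sweep driver (sketch only): of the names below, only
# train(..., gpu_id=...) and its (val_results, loss_results) return
# value come from this commit.
import copy

from runtime import train  # assumed import path for runtime.train

def sweep(network_config, base_hyper_params, params, learning_rates):
    # Run one trial per learning rate, sequentially for clarity; each
    # trial is pinned to its own GPU via the new gpu_id argument.
    histories = {}
    for gpu_id, lr in enumerate(learning_rates):
        hp = copy.deepcopy(base_hyper_params)
        hp['optimization']['params']['learning_rate'] = lr  # assumed key name
        histories[lr] = train(network_config, hp, params, gpu_id=gpu_id)
    return histories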
@@ -124,14 +124,14 @@ class TrainHelper(object):
    @staticmethod
    def nanloss(loss_value):
        if np.isnan(loss_value):
-            print_rank('loss is nan... Exiting!')
-            sys.exit(0)
+            print_rank('loss is nan...')
+            # sys.exit(0)

def print_rank(*args, **kwargs):
    if hvd.rank() == 0:
        print(*args, **kwargs)

-def train(network_config, hyper_params, params):
+def train(network_config, hyper_params, params, gpu_id=None):
    """
    Train the network for a number of steps using horovod and asynchronous I/O staging ops.

@@ -148,7 +148,9 @@ def train(network_config, hyper_params, params):
                           log_device_placement=params['log_device_placement'],
                           )
    config.gpu_options.allow_growth = True
-    config.gpu_options.visible_device_list = str(hvd.local_rank())
+    if gpu_id is None:
+        gpu_id = hvd.local_rank()
+    config.gpu_options.visible_device_list = str(gpu_id)
    config.gpu_options.force_gpu_compatible = True
    config.intra_op_parallelism_threads = 6 
    config.inter_op_parallelism_threads = max(1, cpu_count()//6)
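With the fallback above, hvd.local_rank() is used only when no explicit device is given, so several single-process trials can share the GPUs of one node without colliding on visible_device_list. A sketch of that pattern, assuming one process per trial and six GPUs per node (as on Summit), and glossing over per-process Horovod initialization:

from multiprocessing import Process

from runtime import train  # assumed import path, as above

def launch_trials(trial_args, n_gpus=6):
    # trial_args: a list of (network_config, hyper_params, params) tuples,
    # one per trial; trial i is pinned to GPU i modulo the device count.
    procs = [Process(target=train, args=args, kwargs={'gpu_id': i % n_gpus})
             for i, args in enumerate(trial_args)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()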
@@ -347,17 +349,7 @@ def train(network_config, hyper_params, params):
        print_rank("Restoring from previous checkpoint @ step=%d" % last_step)

    # Train
-    # if hvd.rank() == 0:
-    #     train_elf = TrainHelper(params, saver, summary_writer,  n_net.get_ops(), last_step=last_step, log_freq=params['log_frequency'])
-    # else:
-    #     train_elf = TrainHelper(params, saver, None, n_net.get_ops(), last_step=last_step)
-
    train_elf = TrainHelper(params, saver, summary_writer,  n_net.get_ops(), last_step=last_step, log_freq=params['log_frequency'])
-    # if params['restart']:
-    #     saveStep = train_elf.last_step + params['save_step']
-    #     validateStep = train_elf.last_step + params['validate_step']
-    #     summaryStep = train_elf.last_step + params['summary_step']
-    # else:
    saveStep =  params['save_step']
    validateStep = params['validate_step']
    summaryStep = params['summary_step']
@@ -368,6 +360,9 @@ def train(network_config, hyper_params, params):
    traceStep = params[ 'trace_step' ]
    maxTime = params.get('max_time', 1e12)
    
+    val_results = []
+    loss_results = []
+    loss_value = 1e10
    while train_elf.last_step < maxSteps :
        train_elf.before_run()
        doLog   = bool(train_elf.last_step % logFreq  == 0)
@@ -383,9 +378,11 @@ def train(network_config, hyper_params, params):
            sess.run(train_op)
        elif doLog and not doSave  and not doSumm:
            _, loss_value, lr = sess.run( [ train_op, total_loss, learning_rate ] )
+            loss_results.append((train_elf.last_step, loss_value))
            train_elf.log_stats( loss_value, lr )
        elif doLog and doSumm and doSave :
            _, summary, loss_value, lr = sess.run( [ train_op, summary_merged, total_loss, learning_rate ])
+            loss_results.append((train_elf.last_step, loss_value))
            train_elf.log_stats( loss_value, lr )
            train_elf.write_summaries( summary )
            if hvd.rank( ) == 0 :
@@ -393,6 +390,7 @@ def train(network_config, hyper_params, params):
                print_rank('Saved Checkpoint.')
        elif doLog and doSumm :
            _, summary, loss_value, lr = sess.run( [ train_op, summary_merged, total_loss, learning_rate ])
+            loss_results.append((train_elf.last_step, loss_value))
            train_elf.log_stats( loss_value, lr )
            train_elf.write_summaries( summary )
        elif doSumm:
@@ -408,292 +406,19 @@ def train(network_config, hyper_params, params):
            train_elf.before_run()
        # Here we do validation:
        if doValidate:
-            validate(network_config, hyper_params, params, sess, dset)
-        if doFinish:
-            return
-
-def train_inverter(network_config, hyper_params, params):
-    """
-    Train the network for a number of steps using horovod and asynchronous I/O staging ops.
-
-    :param network_config: OrderedDict, network configuration
-    :param hyper_params: OrderedDict, hyper_parameters
-    :param params: dict
-    :return: None
-    """
-    #########################
-    # Start Session         #
-    #########################
-    # Config file for tf.Session()
-    config = tf.ConfigProto(allow_soft_placement=params['allow_soft_placement'],
-                           log_device_placement=params['log_device_placement'],
-                           )
-    config.gpu_options.allow_growth = True
-    config.gpu_options.visible_device_list = str(hvd.local_rank())
-    config.gpu_options.force_gpu_compatible = True
-    config.intra_op_parallelism_threads = 6
-    config.inter_op_parallelism_threads = max(1, cpu_count()//6)
-    config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
-    jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
-    # JIT causes gcc errors on dgx-dl and is built without on Summit.
-    sess = tf.Session(config=config)
-
-
-    ############################
-    # Setting up Checkpointing #
-    ###########################
-
-    last_step = 0
-    if params[ 'restart' ] :
-       # Check if training is a restart from checkpoint
-       ckpt = tf.train.get_checkpoint_state(params[ 'checkpt_dir' ] )
-       if ckpt is None :
-          print_rank( '<ERROR> Could not restart from checkpoint %s' % params[ 'checkpt_dir' ])
-       else :
-          last_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
-          print_rank("Restoring from previous checkpoint @ step=%d" %last_step)
-
-    global_step = tf.Variable(last_step, name='global_step',trainable=False)
-
-
-    ############################################
-    # Setup Graph, Input pipeline and optimizer#
-    ############################################
-    # Start building the graph
-
-    # Setup data stream
-    with tf.device(params['CPU_ID']):
-        with tf.name_scope('Input') as _:
-            if params['filetype'] == 'tfrecord':
-                dset = inputs.DatasetTFRecords(params, dataset=params['dataset'], debug=False)
-            elif params['filetype'] == 'lmdb':
-                params['data_dir'] = "/data/lmdb_bank_0529"
-                dset = inputs.DatasetLMDB(params, dataset=params['dataset'], debug=params['debug'])
-            images, labels = dset.minibatch()
-            # Staging images on host
-            staging_op, (images, labels) = dset.stage([images, labels])
-
-            ### adding another input pipeline
-            params['data_dir'] = "/data/lmdb_bank_0531_256"
-            dset_2 = inputs.DatasetLMDB(params, dataset=params['dataset'], debug=params['debug'])
-            images_2, labels_2 = dset_2.minibatch()
-            staging_op_2, (images_2, labels_2) = dset_2.stage([images_2, labels_2])
-
-
-    with tf.device('/gpu:%d' % hvd.local_rank()):
-        # Copy images from host to device
-        gpucopy_op, (images, labels) = dset.stage([images, labels])
-        IO_ops = [staging_op, gpucopy_op]
-        # Copy images from host to device
-        gpucopy_op_2, (images_2, labels_2) = dset_2.stage([images_2, labels_2])
-        IO_ops_2 = [staging_op_2, gpucopy_op_2]
-
-        ##################
-        # Building Model#
-        ##################
-
-        # Build model, forward propagate, and calculate loss
-        scope = 'model'
-        summary = False
-        if params['debug']:
-            summary = True
-        print_rank('Starting up queue of images+labels: %s,  %s ' % (format(images.get_shape()),
-                                                                format(labels.get_shape())))
-
-        with tf.variable_scope(scope,
-                # Force all variables to be stored as float32
-                custom_getter=float32_variable_storage_getter) as _:
-
-            # Setup Neural Net
-            if params['network_class'] == 'resnet':
-                n_net = network.ResNet(scope, params, hyper_params, network_config, images, labels,
-                                        operation='train', summary=summary, verbose=False)
-            if params['network_class'] == 'cnn':
-                n_net = network.ConvNet(scope, params, hyper_params, network_config, images, labels,
-                                        operation='train', summary=summary, verbose=True)
-            if params['network_class'] == 'fcdensenet':
-                n_net = network.FCDenseNet(scope, params, hyper_params, network_config, images, labels,
-                                        operation='train', summary=summary, verbose=True)
-            if params['network_class'] == 'fcnet':
-                n_net = network.FCNet(scope, params, hyper_params, network_config, images, labels,
-                                        operation='train', summary=summary, verbose=True)
-
-
-            ###### XLA compilation #########
-            #if params['network_class'] == 'fcdensenet':
-            #    def wrap_n_net(*args):
-            #        images, labels = args
-            #        n_net = network.FCDenseNet(scope, params, hyper_params, network_config, images, labels,
-            #                            operation='train', summary=False, verbose=True)
-            #        n_net.build_model()
-            #        return n_net.model_output
-            #
-            #    n_net.model_output = xla.compile(wrap_n_net, inputs=[images, labels])
-            ##############################
-
-            # Build it and propagate images through it.
-            n_net.build_model()
-
-            # calculate the total loss
-            total_loss, loss_averages_op = losses.calc_loss(n_net, scope, hyper_params, params, labels, summary=summary)
-
-            #get summaries, except for the one produced by string_input_producer
-            if summary: summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
-
-
-        #######################################
-        # Apply Gradients and setup train op #
-        #######################################
-
-        # get learning policy
-        def learning_policy_func(step):
-            return lr_policies.decay_warmup(params, hyper_params, step)
-            ## TODO: implement other policies in lr_policies
-
-        iter_size = params.get('accumulate_step', 0)
-        skip_update_cond = tf.cast(tf.floormod(global_step, tf.constant(iter_size, dtype=tf.int32)), tf.bool)
-
-        if params['IMAGE_FP16']:
-            opt_type='mixed'
-        else:
-            opt_type=tf.float32
-        # setup optimizer
-        opt_dict = hyper_params['optimization']['params']
-        train_opt, learning_rate = optimizers.optimize_loss(total_loss, hyper_params['optimization']['name'],
-                                opt_dict, learning_policy_func, run_params=params, hyper_params=hyper_params, iter_size=iter_size, dtype=opt_type,
-                                loss_scaling=hyper_params.get('loss_scaling',1.0),
-                                skip_update_cond=skip_update_cond,
-                                on_horovod=True, model_scopes=n_net.scopes)
-
-    # Gather all training related ops into a single one.
-    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
-    increment_op = tf.assign_add(global_step, 1)
-    ema = tf.train.ExponentialMovingAverage(decay=0.99, num_updates=global_step)
-    all_ops = tf.group(*([train_opt] + update_ops + IO_ops + [increment_op]))
-    all_ops_2 = tf.group(*([train_opt] + update_ops + IO_ops_2 + [increment_op]))
-
-    with tf.control_dependencies([all_ops]):
-            train_op = ema.apply(tf.trainable_variables())
-            # train_op = tf.no_op(name='train')
-    with tf.control_dependencies([all_ops_2]):
-            train_op_2 = ema.apply(tf.trainable_variables())
-
-    ########################
-    # Setting up Summaries #
-    ########################
-
-    # Stats and summaries
-    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
-    run_metadata = tf.RunMetadata()
-    if hvd.rank() == 0:
-        summary_writer = tf.summary.FileWriter(params['checkpt_dir'], sess.graph)
-        # Add Summary histograms for trainable variables and their gradients
-    if params['debug'] and hyper_params['network_type'] == 'inverter':
-        predic = tf.transpose(n_net.model_output, perm=[0,2,3,1])
-        output_summary = tf.summary.image("outputs", predic, max_outputs=4)
-        tf.summary.image("targets", tf.transpose(labels, perm=[0,2,3,1]), max_outputs=4)
-        tf.summary.image("inputs", tf.transpose(tf.reduce_mean(images, axis=1, keepdims=True), perm=[0,2,3,1]), max_outputs=4)
-
-    summary_merged = tf.summary.merge_all()
-
-     ###############################
-    # Setting up training session #
-    ###############################
-
-    #Initialize variables
-    init_op = tf.global_variables_initializer()
-    sess.run(init_op)
-
-    # Sync
-    print_rank('Syncing horovod ranks...')
-    sync_op = hvd.broadcast_global_variables(0)
-    sess.run(sync_op)
-
-    # prefill pipeline first
-    print_rank('Prefilling I/O pipeline...')
-    for i in range(len(IO_ops)):
-        sess.run(IO_ops[:i + 1])
-        sess.run(IO_ops_2[:i + 1])
-
-
-    # Saver and Checkpoint restore
-    checkpoint_file = os.path.join(params[ 'checkpt_dir' ], 'model.ckpt')
-    saver = tf.train.Saver(max_to_keep=None, save_relative_paths=True)
-
-    # Check if training is a restart from checkpoint
-    if params['restart'] and ckpt is not None:
-        saver.restore(sess, ckpt.model_checkpoint_path)
-        print_rank("Restoring from previous checkpoint @ step=%d" % last_step)
-
-    # Train
-    if hvd.rank() == 0:
-        train_elf = TrainHelper(params, saver, summary_writer,  n_net.get_ops(), last_step=last_step, log_freq=params['log_frequency'])
-    else:
-        train_elf = TrainHelper(params, saver, None, n_net.get_ops(), last_step=last_step)
-
-    if params['restart']:
-        saveStep = train_elf.last_step + params['save_step']
-        validateStep = train_elf.last_step + params['validate_step']
-        summaryStep = train_elf.last_step + params['summary_step']
-    else:
-        saveStep =  params['save_step']
-        validateStep = params['validate_step']
-        summaryStep = params['summary_step']
-
-    train_elf.run_summary( )
-    maxSteps  = params[ 'max_steps' ]
-    logFreq   = params[ 'log_frequency' ]
-    traceStep = params[ 'trace_step' ]
-
-    while train_elf.last_step < maxSteps :
-        train_elf.before_run()
-
-        doLog   = train_elf.last_step % logFreq  == 0
-        doSave  = train_elf.last_step >= saveStep
-        doSumm  = train_elf.last_step > summaryStep
-        doTrace = train_elf.last_step == traceStep and params['gpu_trace']
-        if train_elf.last_step == 1 and params['debug']:
-            if hvd.rank() == 0:
-                summary = sess.run([train_op,  summary_merged])[-1]
-                train_elf.write_summaries( summary )
-        if not doLog and not doSave and not doTrace and not doSumm and bool(train_elf.last_step % 2):
-            if bool(train_elf.last_step % 2):
-                sess.run(train_op)
-            else:
-                sess.run(train_op_2)
-        elif doLog and not doSave :
-            _, loss_value, lr = sess.run( [ train_op, total_loss, learning_rate ] )
-            train_elf.log_stats( loss_value, lr )
-        elif doLog and doSave :
-            _, summary, loss_value, lr = sess.run( [ train_op, summary_merged, total_loss, learning_rate ])
-            train_elf.log_stats( loss_value, lr )
-            train_elf.write_summaries( summary )
-            if hvd.rank( ) == 0 :
-                saver.save(sess, checkpoint_file, global_step=train_elf.last_step)
-                print_rank('Saved Checkpoint.')
-            saveStep += params['save_step']
-        elif doSumm and params['debug']:
-            if hvd.rank() == 0:
-                summary = sess.run([train_op,  summary_merged])[-1]
-                train_elf.write_summaries( summary )
-            summaryStep += params['summary_step']
-        elif doSave :
-            #summary = sess.run([train_op,  summary_merged])[-1]
-            #train_elf.write_summaries( summary )
-            if hvd.rank( ) == 0 :
-                saver.save(sess, checkpoint_file, global_step=train_elf.last_step)
-                print_rank('Saved Checkpoint.')
-            saveStep += params['save_step']
-        elif doTrace :
-            sess.run(train_op, options=run_options, run_metadata=run_metadata)
-            train_elf.save_trace(run_metadata, params[ 'trace_dir' ], params[ 'trace_step' ] )
-            train_elf.before_run()
-        # Here we do validation:
-        #if train_elf.elapsed_epochs > next_validation_epoch:
-        if train_elf.last_step > validateStep:
-            validate(network_config, hyper_params, params, sess, dset)
-            validateStep += params['validate_step']
-            #next_validation_epoch += params['epochs_per_validation']
+            val = validate(network_config, hyper_params, params, sess, dset)
+            val_results.append((train_elf.last_step, val))
+        if doFinish or np.isnan(loss_value):
+            val = validate(network_config, hyper_params, params, sess, dset)
+            val_results.append((train_elf.last_step, val))
+            tf.reset_default_graph()
+            tf.keras.backend.clear_session()
+            sess.close()
+            return val_results, loss_results
+    tf.reset_default_graph()
+    tf.keras.backend.clear_session()
+    sess.close()
+    return val_results, loss_results
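Because nanloss() no longer exits, a diverging trial falls through to the np.isnan(loss_value) branch above, tears down its graph and session, and still returns whatever history it accumulated. A sketch of how a driver might rank finished trials from those return values, assuming a loss-like validation metric where lower is better; rank_trials() is hypothetical:

import numpy as np

def rank_trials(results):
    # results: a list of (hyper_params, (val_results, loss_results)) tuples;
    # each history holds (global_step, value) pairs as recorded by train().
    scored = []
    for hp, (val_hist, _loss_hist) in results:
        if not val_hist:  # trial diverged before its first validation
            continue
        _, final_val = val_hist[-1]
        scored.append((float(np.mean(final_val)), hp))
    scored.sort(key=lambda t: t[0])  # best (lowest) validation error first
    return scored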

def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
    """
@@ -762,12 +487,14 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
        #TODO: implement prediction layer for hybrid network

    # Do evaluation
+    result = None
    if hyper_params['network_type'] == 'regressor':
        validation_error = tf.losses.mean_squared_error(labels, predictions=logits, reduction=tf.losses.Reduction.NONE)
        # Average validation error over the batches
        errors = np.array([sess.run(validation_error) for _ in range(num_batches)])
        errors = errors.reshape(-1, params['NUM_CLASSES'])
        avg_errors = errors.mean(0)
+        result = avg_errors
        print_rank('Validation MSE: %s' % format(avg_errors))
    elif hyper_params['network_type'] == 'classifier':
        labels = tf.argmax(labels, axis=1)
@@ -780,6 +507,7 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
        val_loss = output[:,-1]
        accuracy = accuracy.sum(axis=(0,-1))/(num_batches*params['batch_size'])*100
        val_loss = val_loss.sum()/(num_batches*params['batch_size'])
+        result = accuracy
        print_rank('Validation Accuracy (.pct), Top-1: %2.2f , Top-5: %2.2f, Loss: %2.2f' %(accuracy[0], accuracy[1], val_loss))
    elif hyper_params['network_type'] == 'hybrid':
        #TODO: implement evaluation call for hybrid network
@@ -800,8 +528,9 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
        else:
            num_samples = dset.num_samples
        #error = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(4)])
-        error = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples//params['batch_size'])])
-        print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, error.mean()))
+        errors = np.array([sess.run([IO_ops,error_averaging])[-1] for i in range(num_samples//params['batch_size'])])
+        result = errors.mean()
+        print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, errors.mean()))
    elif hyper_params['network_type'] == 'inverter':
        loss_params = hyper_params['loss_function']
        if labels.shape.as_list()[1] > 1:
@@ -832,8 +561,10 @@ def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
        # avg_errors = hvd.allreduce(tf.expand_dims(errors, axis=0))
        # error = sess.run(avg_errors)
        # print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, errors.mean()))
+        result = errors.mean()
        print_rank('Validation Reconstruction Error %s: %3.3e' % (loss_label, errors.mean()))
        tf.summary.scalar("Validation_loss_label_%s" % loss_label, tf.constant(errors.mean()))
+    return result

def validate_ckpt(network_config, hyper_params, params, num_batches=None,
                    last_model= False, sleep=-1):