Loading stemdl/losses.py +2 −0 Original line number Diff line number Diff line Loading @@ -118,6 +118,8 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None total_loss = tf.cast(total_loss, tf.float32) # Generate summaries for the losses and get corresponding op loss_averages_op = _add_loss_summaries(total_loss, losses, summaries=summary) if hyper_params['network_type'] == 'YNet': return total_loss, loss_averages_op, losses return total_loss, loss_averages_op def fully_connected(n_net, layer_params, batch_size, wd=0, name=None, reuse=None): Loading stemdl/runtime.py +292 −1 Original line number Diff line number Diff line Loading @@ -48,7 +48,7 @@ def float32_variable_storage_getter(getter, name, shape=None, dtype=None, variable = tf.cast(variable, dtype) return variable class TrainHelper(object): class TrainHelper: def __init__(self, params, saver, writer, net_ops, last_step=0, log_freq=1): self.params = params self.last_step = last_step Loading Loading @@ -127,6 +127,22 @@ class TrainHelper(object): print_rank('loss is nan...') # sys.exit(0) class TrainHelper_YNet(TrainHelper): def log_stats(self, loss_value, aux_losses, learning_rate): t = time.time( ) duration = t - self.start_time examples_per_sec = self.params['batch_size'] * hvd.size() / duration self.cumm_time = (time.time() - self.cumm_time)/self.log_freq flops = self.net_ops * examples_per_sec avg_flops = self.net_ops * self.params['batch_size'] * hvd.size() / self.cumm_time loss_inv, loss_dec_re, loss_dec_im, loss_reg = aux_losses self.nanloss(loss_value) format_str = ( 'time= %.1f, step= %2.2e, epoch= %2.2e, lr= %.2e, loss=%.3e, loss_inv= %.2e, loss_dec_im=%.2e, loss_dec_re=%.2e, loss_reg=%.2e, step_time= %2.2f sec, ranks= %d, examples/sec= %.1f') print_rank(format_str % ( t - self.params[ 'start_time' ], self.last_step, self.elapsed_epochs, learning_rate, loss_value, loss_inv, loss_dec_im, loss_dec_re, loss_reg, duration, hvd.size(), examples_per_sec)) self.cumm_time = time.time() def print_rank(*args, **kwargs): if hvd.rank() == 0: print(*args, **kwargs) Loading Loading @@ -424,6 +440,281 @@ def train(network_config, hyper_params, params, gpu_id=None): sess.close() return val_results, loss_results def train_YNet(network_config, hyper_params, params, gpu_id=None): """ Train the network for a number of steps using horovod and asynchronous I/O staging ops. :param network_config: OrderedDict, network configuration :param hyper_params: OrderedDict, hyper_parameters :param params: dict :return: None """ ######################### # Start Session # ######################### # Config file for tf.Session() config = tf.ConfigProto(allow_soft_placement=params['allow_soft_placement'], log_device_placement=params['log_device_placement'], ) config.gpu_options.allow_growth = True if gpu_id is None: gpu_id = hvd.local_rank() config.gpu_options.visible_device_list = str(gpu_id) config.gpu_options.force_gpu_compatible = True config.intra_op_parallelism_threads = 6 config.inter_op_parallelism_threads = max(1, cpu_count()//6) #config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 #jit_scope = tf.contrib.compiler.jit.experimental_jit_scope # JIT causes gcc errors on dgx-dl and is built without on Summit. sess = tf.Session(config=config) ############################ # Setting up Checkpointing # ########################### last_step = 0 if params[ 'restart' ] : # Check if training is a restart from checkpoint ckpt = tf.train.get_checkpoint_state(params[ 'checkpt_dir' ] ) if ckpt is None : print_rank( '<ERROR> Could not restart from checkpoint %s' % params[ 'checkpt_dir' ]) else : last_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) print_rank("Restoring from previous checkpoint @ step=%d" %last_step) global_step = tf.Variable(last_step, name='global_step',trainable=False) ############################################ # Setup Graph, Input pipeline and optimizer# ############################################ # Start building the graph # Setup data stream with tf.device(params['CPU_ID']): with tf.name_scope('Input') as _: if params['filetype'] == 'tfrecord': dset = inputs.DatasetTFRecords(params, dataset=params['dataset'], debug=False) elif params['filetype'] == 'lmdb': dset = inputs.DatasetLMDB(params, dataset=params['dataset'], debug=params['debug']) images, labels = dset.minibatch() # Staging images on host staging_op, (images, labels) = dset.stage([images, labels]) with tf.device('/gpu:%d' % hvd.local_rank()): # Copy images from host to device gpucopy_op, (images, labels) = dset.stage([images, labels]) IO_ops = [staging_op, gpucopy_op] ################## # Building Model# ################## # Build model, forward propagate, and calculate loss scope = 'model' summary = False if params['debug']: summary = True print_rank('Starting up queue of images+labels: %s, %s ' % (format(images.get_shape()), format(labels.get_shape()))) with tf.variable_scope(scope, # Force all variables to be stored as float32 custom_getter=float32_variable_storage_getter) as _: # Setup Neural Net n_net = network.YNet(scope, params, hyper_params, network_config, images, labels, operation='train', summary=summary, verbose=True) ###### XLA compilation ######### #if params['network_class'] == 'fcdensenet': # def wrap_n_net(*args): # images, labels = args # n_net = network.FCDenseNet(scope, params, hyper_params, network_config, images, labels, # operation='train', summary=False, verbose=True) # n_net.build_model() # return n_net.model_output # # n_net.model_output = xla.compile(wrap_n_net, inputs=[images, labels]) ############################## # Build it and propagate images through it. n_net.build_model() # calculate the total loss total_loss, _, indv_losses = losses.calc_loss(n_net, scope, hyper_params, params, labels, step=global_step, images=images, summary=summary) #get summaries, except for the one produced by string_input_producer if summary: summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope) # print_rank([scope.name for scope in n_net.scopes]) ####################################### # Apply Gradients and setup train op # ####################################### # get learning policy def learning_policy_func(step): return lr_policies.decay_warmup(params, hyper_params, step) ## TODO: implement other policies in lr_policies iter_size = params.get('accumulate_step', 0) skip_update_cond = tf.cast(tf.floormod(global_step, tf.constant(iter_size, dtype=tf.int32)), tf.bool) if params['IMAGE_FP16']: opt_type='mixed' else: opt_type=tf.float32 # setup optimizer opt_dict = hyper_params['optimization']['params'] train_opt, learning_rate = optimizers.optimize_loss(total_loss, hyper_params['optimization']['name'], opt_dict, learning_policy_func, run_params=params, hyper_params=hyper_params, iter_size=iter_size, dtype=opt_type, loss_scaling=hyper_params.get('loss_scaling',1.0), skip_update_cond=skip_update_cond, on_horovod=True, model_scopes=n_net.scopes) # Gather all training related ops into a single one. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) increment_op = tf.assign_add(global_step, 1) ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step) all_ops = tf.group(*([train_opt] + update_ops + IO_ops + [increment_op])) with tf.control_dependencies([all_ops]): train_op = ema.apply(tf.trainable_variables()) # train_op = tf.no_op(name='train') ######################## # Setting up Summaries # ######################## # Stats and summaries run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() # if hvd.rank() == 0: summary_writer = tf.summary.FileWriter(os.path.join(params['checkpt_dir'], str(hvd.rank())), sess.graph) # Add Summary histograms for trainable variables and their gradients if params['debug']: predic_inverter = tf.transpose(n_net.model_output['inverter'], perm=[0,2,3,1]) tf.summary.image("output_inverter", predic_inverter, max_outputs=2) predic_decoder_RE = tf.transpose(n_net.model_output['decoder_RE'], perm=[0,2,3,1]) predic_decoder_IM = tf.transpose(n_net.model_output['decoder_IM'], perm=[0,2,3,1]) tf.summary.image("output_decoder_RE", predic_decoder_RE, max_outputs=2) tf.summary.image("output_decoder_IM", predic_decoder_IM, max_outputs=2) new_labels = tf.unstack(labels, axis=1) for label, tag in zip(new_labels, ['potential', 'probe_RE', 'probe_IM']): label = tf.expand_dims(label, axis=-1) # label = tf.transpose(label, perm=[0,2,3,1]) tf.summary.image(tag, label, max_outputs=2) tf.summary.image("inputs", tf.transpose(tf.reduce_mean(images, axis=1, keepdims=True), perm=[0,2,3,1]), max_outputs=4) summary_merged = tf.summary.merge_all() ############################### # Setting up training session # ############################### #Initialize variables init_op = tf.global_variables_initializer() sess.run(init_op) # Sync print_rank('Syncing horovod ranks...') sync_op = hvd.broadcast_global_variables(0) sess.run(sync_op) # prefill pipeline first print_rank('Prefilling I/O pipeline...') for i in range(len(IO_ops)): sess.run(IO_ops[:i + 1]) # Saver and Checkpoint restore checkpoint_file = os.path.join(params[ 'checkpt_dir' ], 'model.ckpt') saver = tf.train.Saver(max_to_keep=None, save_relative_paths=True) # Check if training is a restart from checkpoint if params['restart'] and ckpt is not None: saver.restore(sess, ckpt.model_checkpoint_path) print_rank("Restoring from previous checkpoint @ step=%d" % last_step) # Train train_elf = TrainHelper_YNet(params, saver, summary_writer, n_net.get_ops(), last_step=last_step, log_freq=params['log_frequency']) saveStep = params['save_step'] validateStep = params['validate_step'] summaryStep = params['summary_step'] train_elf.run_summary() maxSteps = params[ 'max_steps' ] logFreq = params[ 'log_frequency' ] traceStep = params[ 'trace_step' ] maxTime = params.get('max_time', 1e12) val_results = [] loss_results = [] loss_value = 1e10 val = 1e10 while train_elf.last_step < maxSteps : train_elf.before_run() doLog = bool(train_elf.last_step % logFreq == 0) doSave = bool(train_elf.last_step % saveStep == 0) doSumm = bool(train_elf.last_step % summaryStep == 0 and params['debug']) doTrace = bool(train_elf.last_step == traceStep and params['gpu_trace']) doValidate = bool(train_elf.last_step % validateStep == 0) doFinish = bool(train_elf.start_time - params['start_time'] > maxTime) if train_elf.last_step == 1 and params['debug']: summary = sess.run([train_op, summary_merged])[-1] train_elf.write_summaries( summary ) elif not doLog and not doSave and not doTrace and not doSumm: sess.run(train_op) elif doLog and not doSave and not doSumm: _, lr, loss_value, aux_losses = sess.run( [ train_op, learning_rate, total_loss, indv_losses]) loss_results.append((train_elf.last_step, loss_value)) train_elf.log_stats( loss_value, aux_losses, lr) elif doLog and doSumm and doSave : _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate ]) loss_results.append((train_elf.last_step, loss_value)) train_elf.log_stats( loss_value, aux_losses, lr ) train_elf.write_summaries( summary ) if hvd.rank( ) == 0 : saver.save(sess, checkpoint_file, global_step=train_elf.last_step) print_rank('Saved Checkpoint.') elif doLog and doSumm : _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate ]) loss_results.append((train_elf.last_step, loss_value)) train_elf.log_stats( loss_value, aux_losses, lr ) train_elf.write_summaries( summary ) elif doSumm: summary = sess.run([train_op, summary_merged])[-1] train_elf.write_summaries( summary ) elif doSave : if hvd.rank( ) == 0 : saver.save(sess, checkpoint_file, global_step=train_elf.last_step) print_rank('Saved Checkpoint.') elif doTrace : sess.run(train_op, options=run_options, run_metadata=run_metadata) train_elf.save_trace(run_metadata, params[ 'trace_dir' ], params[ 'trace_step' ] ) train_elf.before_run() # Here we do validation: if doValidate: val = validate(network_config, hyper_params, params, sess, dset, num_batches=50) val_results.append((train_elf.last_step,val)) if doFinish: #val = validate(network_config, hyper_params, params, sess, dset, num_batches=50) #val_results.append((train_elf.last_step, val)) tf.reset_default_graph() tf.keras.backend.clear_session() sess.close() return val_results, loss_results if np.isnan(loss_value): break val_results.append((train_elf.last_step,val)) tf.reset_default_graph() tf.keras.backend.clear_session() sess.close() return val_results, loss_results def validate(network_config, hyper_params, params, sess, dset, num_batches=10): """ Runs validation with current weights Loading Loading
stemdl/losses.py +2 −0 Original line number Diff line number Diff line Loading @@ -118,6 +118,8 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None total_loss = tf.cast(total_loss, tf.float32) # Generate summaries for the losses and get corresponding op loss_averages_op = _add_loss_summaries(total_loss, losses, summaries=summary) if hyper_params['network_type'] == 'YNet': return total_loss, loss_averages_op, losses return total_loss, loss_averages_op def fully_connected(n_net, layer_params, batch_size, wd=0, name=None, reuse=None): Loading
stemdl/runtime.py +292 −1 Original line number Diff line number Diff line Loading @@ -48,7 +48,7 @@ def float32_variable_storage_getter(getter, name, shape=None, dtype=None, variable = tf.cast(variable, dtype) return variable class TrainHelper(object): class TrainHelper: def __init__(self, params, saver, writer, net_ops, last_step=0, log_freq=1): self.params = params self.last_step = last_step Loading Loading @@ -127,6 +127,22 @@ class TrainHelper(object): print_rank('loss is nan...') # sys.exit(0) class TrainHelper_YNet(TrainHelper): def log_stats(self, loss_value, aux_losses, learning_rate): t = time.time( ) duration = t - self.start_time examples_per_sec = self.params['batch_size'] * hvd.size() / duration self.cumm_time = (time.time() - self.cumm_time)/self.log_freq flops = self.net_ops * examples_per_sec avg_flops = self.net_ops * self.params['batch_size'] * hvd.size() / self.cumm_time loss_inv, loss_dec_re, loss_dec_im, loss_reg = aux_losses self.nanloss(loss_value) format_str = ( 'time= %.1f, step= %2.2e, epoch= %2.2e, lr= %.2e, loss=%.3e, loss_inv= %.2e, loss_dec_im=%.2e, loss_dec_re=%.2e, loss_reg=%.2e, step_time= %2.2f sec, ranks= %d, examples/sec= %.1f') print_rank(format_str % ( t - self.params[ 'start_time' ], self.last_step, self.elapsed_epochs, learning_rate, loss_value, loss_inv, loss_dec_im, loss_dec_re, loss_reg, duration, hvd.size(), examples_per_sec)) self.cumm_time = time.time() def print_rank(*args, **kwargs): if hvd.rank() == 0: print(*args, **kwargs) Loading Loading @@ -424,6 +440,281 @@ def train(network_config, hyper_params, params, gpu_id=None): sess.close() return val_results, loss_results def train_YNet(network_config, hyper_params, params, gpu_id=None): """ Train the network for a number of steps using horovod and asynchronous I/O staging ops. :param network_config: OrderedDict, network configuration :param hyper_params: OrderedDict, hyper_parameters :param params: dict :return: None """ ######################### # Start Session # ######################### # Config file for tf.Session() config = tf.ConfigProto(allow_soft_placement=params['allow_soft_placement'], log_device_placement=params['log_device_placement'], ) config.gpu_options.allow_growth = True if gpu_id is None: gpu_id = hvd.local_rank() config.gpu_options.visible_device_list = str(gpu_id) config.gpu_options.force_gpu_compatible = True config.intra_op_parallelism_threads = 6 config.inter_op_parallelism_threads = max(1, cpu_count()//6) #config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 #jit_scope = tf.contrib.compiler.jit.experimental_jit_scope # JIT causes gcc errors on dgx-dl and is built without on Summit. sess = tf.Session(config=config) ############################ # Setting up Checkpointing # ########################### last_step = 0 if params[ 'restart' ] : # Check if training is a restart from checkpoint ckpt = tf.train.get_checkpoint_state(params[ 'checkpt_dir' ] ) if ckpt is None : print_rank( '<ERROR> Could not restart from checkpoint %s' % params[ 'checkpt_dir' ]) else : last_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) print_rank("Restoring from previous checkpoint @ step=%d" %last_step) global_step = tf.Variable(last_step, name='global_step',trainable=False) ############################################ # Setup Graph, Input pipeline and optimizer# ############################################ # Start building the graph # Setup data stream with tf.device(params['CPU_ID']): with tf.name_scope('Input') as _: if params['filetype'] == 'tfrecord': dset = inputs.DatasetTFRecords(params, dataset=params['dataset'], debug=False) elif params['filetype'] == 'lmdb': dset = inputs.DatasetLMDB(params, dataset=params['dataset'], debug=params['debug']) images, labels = dset.minibatch() # Staging images on host staging_op, (images, labels) = dset.stage([images, labels]) with tf.device('/gpu:%d' % hvd.local_rank()): # Copy images from host to device gpucopy_op, (images, labels) = dset.stage([images, labels]) IO_ops = [staging_op, gpucopy_op] ################## # Building Model# ################## # Build model, forward propagate, and calculate loss scope = 'model' summary = False if params['debug']: summary = True print_rank('Starting up queue of images+labels: %s, %s ' % (format(images.get_shape()), format(labels.get_shape()))) with tf.variable_scope(scope, # Force all variables to be stored as float32 custom_getter=float32_variable_storage_getter) as _: # Setup Neural Net n_net = network.YNet(scope, params, hyper_params, network_config, images, labels, operation='train', summary=summary, verbose=True) ###### XLA compilation ######### #if params['network_class'] == 'fcdensenet': # def wrap_n_net(*args): # images, labels = args # n_net = network.FCDenseNet(scope, params, hyper_params, network_config, images, labels, # operation='train', summary=False, verbose=True) # n_net.build_model() # return n_net.model_output # # n_net.model_output = xla.compile(wrap_n_net, inputs=[images, labels]) ############################## # Build it and propagate images through it. n_net.build_model() # calculate the total loss total_loss, _, indv_losses = losses.calc_loss(n_net, scope, hyper_params, params, labels, step=global_step, images=images, summary=summary) #get summaries, except for the one produced by string_input_producer if summary: summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope) # print_rank([scope.name for scope in n_net.scopes]) ####################################### # Apply Gradients and setup train op # ####################################### # get learning policy def learning_policy_func(step): return lr_policies.decay_warmup(params, hyper_params, step) ## TODO: implement other policies in lr_policies iter_size = params.get('accumulate_step', 0) skip_update_cond = tf.cast(tf.floormod(global_step, tf.constant(iter_size, dtype=tf.int32)), tf.bool) if params['IMAGE_FP16']: opt_type='mixed' else: opt_type=tf.float32 # setup optimizer opt_dict = hyper_params['optimization']['params'] train_opt, learning_rate = optimizers.optimize_loss(total_loss, hyper_params['optimization']['name'], opt_dict, learning_policy_func, run_params=params, hyper_params=hyper_params, iter_size=iter_size, dtype=opt_type, loss_scaling=hyper_params.get('loss_scaling',1.0), skip_update_cond=skip_update_cond, on_horovod=True, model_scopes=n_net.scopes) # Gather all training related ops into a single one. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) increment_op = tf.assign_add(global_step, 1) ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step) all_ops = tf.group(*([train_opt] + update_ops + IO_ops + [increment_op])) with tf.control_dependencies([all_ops]): train_op = ema.apply(tf.trainable_variables()) # train_op = tf.no_op(name='train') ######################## # Setting up Summaries # ######################## # Stats and summaries run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() # if hvd.rank() == 0: summary_writer = tf.summary.FileWriter(os.path.join(params['checkpt_dir'], str(hvd.rank())), sess.graph) # Add Summary histograms for trainable variables and their gradients if params['debug']: predic_inverter = tf.transpose(n_net.model_output['inverter'], perm=[0,2,3,1]) tf.summary.image("output_inverter", predic_inverter, max_outputs=2) predic_decoder_RE = tf.transpose(n_net.model_output['decoder_RE'], perm=[0,2,3,1]) predic_decoder_IM = tf.transpose(n_net.model_output['decoder_IM'], perm=[0,2,3,1]) tf.summary.image("output_decoder_RE", predic_decoder_RE, max_outputs=2) tf.summary.image("output_decoder_IM", predic_decoder_IM, max_outputs=2) new_labels = tf.unstack(labels, axis=1) for label, tag in zip(new_labels, ['potential', 'probe_RE', 'probe_IM']): label = tf.expand_dims(label, axis=-1) # label = tf.transpose(label, perm=[0,2,3,1]) tf.summary.image(tag, label, max_outputs=2) tf.summary.image("inputs", tf.transpose(tf.reduce_mean(images, axis=1, keepdims=True), perm=[0,2,3,1]), max_outputs=4) summary_merged = tf.summary.merge_all() ############################### # Setting up training session # ############################### #Initialize variables init_op = tf.global_variables_initializer() sess.run(init_op) # Sync print_rank('Syncing horovod ranks...') sync_op = hvd.broadcast_global_variables(0) sess.run(sync_op) # prefill pipeline first print_rank('Prefilling I/O pipeline...') for i in range(len(IO_ops)): sess.run(IO_ops[:i + 1]) # Saver and Checkpoint restore checkpoint_file = os.path.join(params[ 'checkpt_dir' ], 'model.ckpt') saver = tf.train.Saver(max_to_keep=None, save_relative_paths=True) # Check if training is a restart from checkpoint if params['restart'] and ckpt is not None: saver.restore(sess, ckpt.model_checkpoint_path) print_rank("Restoring from previous checkpoint @ step=%d" % last_step) # Train train_elf = TrainHelper_YNet(params, saver, summary_writer, n_net.get_ops(), last_step=last_step, log_freq=params['log_frequency']) saveStep = params['save_step'] validateStep = params['validate_step'] summaryStep = params['summary_step'] train_elf.run_summary() maxSteps = params[ 'max_steps' ] logFreq = params[ 'log_frequency' ] traceStep = params[ 'trace_step' ] maxTime = params.get('max_time', 1e12) val_results = [] loss_results = [] loss_value = 1e10 val = 1e10 while train_elf.last_step < maxSteps : train_elf.before_run() doLog = bool(train_elf.last_step % logFreq == 0) doSave = bool(train_elf.last_step % saveStep == 0) doSumm = bool(train_elf.last_step % summaryStep == 0 and params['debug']) doTrace = bool(train_elf.last_step == traceStep and params['gpu_trace']) doValidate = bool(train_elf.last_step % validateStep == 0) doFinish = bool(train_elf.start_time - params['start_time'] > maxTime) if train_elf.last_step == 1 and params['debug']: summary = sess.run([train_op, summary_merged])[-1] train_elf.write_summaries( summary ) elif not doLog and not doSave and not doTrace and not doSumm: sess.run(train_op) elif doLog and not doSave and not doSumm: _, lr, loss_value, aux_losses = sess.run( [ train_op, learning_rate, total_loss, indv_losses]) loss_results.append((train_elf.last_step, loss_value)) train_elf.log_stats( loss_value, aux_losses, lr) elif doLog and doSumm and doSave : _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate ]) loss_results.append((train_elf.last_step, loss_value)) train_elf.log_stats( loss_value, aux_losses, lr ) train_elf.write_summaries( summary ) if hvd.rank( ) == 0 : saver.save(sess, checkpoint_file, global_step=train_elf.last_step) print_rank('Saved Checkpoint.') elif doLog and doSumm : _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate ]) loss_results.append((train_elf.last_step, loss_value)) train_elf.log_stats( loss_value, aux_losses, lr ) train_elf.write_summaries( summary ) elif doSumm: summary = sess.run([train_op, summary_merged])[-1] train_elf.write_summaries( summary ) elif doSave : if hvd.rank( ) == 0 : saver.save(sess, checkpoint_file, global_step=train_elf.last_step) print_rank('Saved Checkpoint.') elif doTrace : sess.run(train_op, options=run_options, run_metadata=run_metadata) train_elf.save_trace(run_metadata, params[ 'trace_dir' ], params[ 'trace_step' ] ) train_elf.before_run() # Here we do validation: if doValidate: val = validate(network_config, hyper_params, params, sess, dset, num_batches=50) val_results.append((train_elf.last_step,val)) if doFinish: #val = validate(network_config, hyper_params, params, sess, dset, num_batches=50) #val_results.append((train_elf.last_step, val)) tf.reset_default_graph() tf.keras.backend.clear_session() sess.close() return val_results, loss_results if np.isnan(loss_value): break val_results.append((train_elf.last_step,val)) tf.reset_default_graph() tf.keras.backend.clear_session() sess.close() return val_results, loss_results def validate(network_config, hyper_params, params, sess, dset, num_batches=10): """ Runs validation with current weights Loading