Commit 6d8ddd4d authored by Laanait, Nouamane

new training function and class for YNet

parent 10d37686
Pipeline #80781 failed with stage in 2 minutes and 34 seconds
@@ -118,6 +118,8 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None
    total_loss = tf.cast(total_loss, tf.float32)
    # Generate summaries for the losses and get corresponding op
    loss_averages_op = _add_loss_summaries(total_loss, losses, summaries=summary)
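    # YNet is multi-headed, so the individual loss terms are returned as well for per-term logging.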
    if hyper_params['network_type'] == 'YNet':
        return total_loss, loss_averages_op, losses
    return total_loss, loss_averages_op
def fully_connected(n_net, layer_params, batch_size, wd=0, name=None, reuse=None):
@@ -48,7 +48,7 @@ def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
        variable = tf.cast(variable, dtype)
    return variable
class TrainHelper(object):
class TrainHelper:
    def __init__(self, params, saver, writer, net_ops, last_step=0, log_freq=1):
        self.params = params
        self.last_step = last_step
@@ -127,6 +127,22 @@ class TrainHelper(object):
            print_rank('loss is nan...')
            # sys.exit(0)

class TrainHelper_YNet(TrainHelper):
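    # Same bookkeeping as TrainHelper, but log_stats also reports YNet's individual loss terms.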
    def log_stats(self, loss_value, aux_losses, learning_rate):
        t = time.time()
        duration = t - self.start_time
        examples_per_sec = self.params['batch_size'] * hvd.size() / duration
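        # cumm_time holds the wall-clock stamp of the previous log call; turn it into an average step time over the logging window.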
        self.cumm_time = (time.time() - self.cumm_time) / self.log_freq
        flops = self.net_ops * examples_per_sec
        avg_flops = self.net_ops * self.params['batch_size'] * hvd.size() / self.cumm_time
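        # Rough FLOP-throughput estimates derived from the static op count (net_ops).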
        loss_inv, loss_dec_re, loss_dec_im, loss_reg = aux_losses
        self.nanloss(loss_value)
        format_str = ('time= %.1f, step= %2.2e, epoch= %2.2e, lr= %.2e, loss= %.3e, loss_inv= %.2e, '
                      'loss_dec_im= %.2e, loss_dec_re= %.2e, loss_reg= %.2e, step_time= %2.2f sec, '
                      'ranks= %d, examples/sec= %.1f')
        print_rank(format_str % (t - self.params['start_time'], self.last_step, self.elapsed_epochs,
                                 learning_rate, loss_value, loss_inv, loss_dec_im, loss_dec_re,
                                 loss_reg, duration, hvd.size(), examples_per_sec))
        self.cumm_time = time.time()

def print_rank(*args, **kwargs):
    if hvd.rank() == 0:
        print(*args, **kwargs)
@@ -424,6 +440,281 @@ def train(network_config, hyper_params, params, gpu_id=None):
    sess.close()
    return val_results, loss_results

def train_YNet(network_config, hyper_params, params, gpu_id=None):
    """
    Train the YNet network for a number of steps using horovod and asynchronous I/O staging ops.
    :param network_config: OrderedDict, network configuration
    :param hyper_params: OrderedDict, hyper-parameters
    :param params: dict, runtime parameters
    :param gpu_id: int or None, GPU to pin this rank to (defaults to hvd.local_rank())
    :return: (val_results, loss_results), lists of (step, value) pairs
    """
    #########################
    #     Start Session     #
    #########################
    # Config for tf.Session()
    config = tf.ConfigProto(allow_soft_placement=params['allow_soft_placement'],
                            log_device_placement=params['log_device_placement'])
    config.gpu_options.allow_growth = True
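    # Pin each horovod rank to a single GPU.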
    if gpu_id is None:
        gpu_id = hvd.local_rank()
    config.gpu_options.visible_device_list = str(gpu_id)
    config.gpu_options.force_gpu_compatible = True
    config.intra_op_parallelism_threads = 6
    config.inter_op_parallelism_threads = max(1, cpu_count() // 6)
    # config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    # jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
    # JIT causes gcc errors on dgx-dl, and TF is built without it on Summit.
    sess = tf.Session(config=config)

    ############################
    # Setting up Checkpointing #
    ############################
    last_step = 0
    if params['restart']:
        # Check if training is a restart from checkpoint
        ckpt = tf.train.get_checkpoint_state(params['checkpt_dir'])
        if ckpt is None:
            print_rank('<ERROR> Could not restart from checkpoint %s' % params['checkpt_dir'])
        else:
            last_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print_rank("Restoring from previous checkpoint @ step=%d" % last_step)
    global_step = tf.Variable(last_step, name='global_step', trainable=False)

    #############################################
    # Setup Graph, Input pipeline and optimizer #
    #############################################
    # Start building the graph
    # Setup data stream
    with tf.device(params['CPU_ID']):
        with tf.name_scope('Input') as _:
            if params['filetype'] == 'tfrecord':
                dset = inputs.DatasetTFRecords(params, dataset=params['dataset'], debug=False)
            elif params['filetype'] == 'lmdb':
                dset = inputs.DatasetLMDB(params, dataset=params['dataset'], debug=params['debug'])
            images, labels = dset.minibatch()
            # Staging images on host
            staging_op, (images, labels) = dset.stage([images, labels])
    with tf.device('/gpu:%d' % hvd.local_rank()):
        # Copy images from host to device
        gpucopy_op, (images, labels) = dset.stage([images, labels])
    IO_ops = [staging_op, gpucopy_op]
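    # Two chained staging areas double-buffer the input: one batch parked on the host, the next already on the GPU.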

    ##################
    # Building Model #
    ##################
    # Build model, forward propagate, and calculate loss
    scope = 'model'
    summary = False
    if params['debug']:
        summary = True
    print_rank('Starting up queue of images+labels: %s, %s ' % (format(images.get_shape()),
                                                                format(labels.get_shape())))
    with tf.variable_scope(scope,
                           # Force all variables to be stored as float32
                           custom_getter=float32_variable_storage_getter) as _:
        # Setup Neural Net
        n_net = network.YNet(scope, params, hyper_params, network_config, images, labels,
                             operation='train', summary=summary, verbose=True)
        ###### XLA compilation #########
        # if params['network_class'] == 'fcdensenet':
        #     def wrap_n_net(*args):
        #         images, labels = args
        #         n_net = network.FCDenseNet(scope, params, hyper_params, network_config, images, labels,
        #                                    operation='train', summary=False, verbose=True)
        #         n_net.build_model()
        #         return n_net.model_output
        #
        #     n_net.model_output = xla.compile(wrap_n_net, inputs=[images, labels])
        ################################
        # Build it and propagate images through it.
        n_net.build_model()
    # calculate the total loss
    total_loss, _, indv_losses = losses.calc_loss(n_net, scope, hyper_params, params, labels,
                                                  step=global_step, images=images, summary=summary)
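    # indv_losses carries YNet's per-head terms (inverter, decoder real/imag, regularization), consumed by log_stats.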
    # get summaries, except for the one produced by string_input_producer
    if summary:
        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
    # print_rank([scope.name for scope in n_net.scopes])

    ######################################
    # Apply Gradients and setup train op #
    ######################################
    # get learning policy
    def learning_policy_func(step):
        return lr_policies.decay_warmup(params, hyper_params, step)
    # TODO: implement other policies in lr_policies

    iter_size = params.get('accumulate_step', 0)
    skip_update_cond = tf.cast(tf.floormod(global_step, tf.constant(iter_size, dtype=tf.int32)), tf.bool)
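    # Gradient accumulation: step % iter_size != 0 casts to True, so the weight update is skipped between accumulation boundaries.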
    if params['IMAGE_FP16']:
        opt_type = 'mixed'
    else:
        opt_type = tf.float32
    # setup optimizer
    opt_dict = hyper_params['optimization']['params']
    train_opt, learning_rate = optimizers.optimize_loss(total_loss, hyper_params['optimization']['name'],
                                                        opt_dict, learning_policy_func, run_params=params,
                                                        hyper_params=hyper_params, iter_size=iter_size,
                                                        dtype=opt_type,
                                                        loss_scaling=hyper_params.get('loss_scaling', 1.0),
                                                        skip_update_cond=skip_update_cond,
                                                        on_horovod=True, model_scopes=n_net.scopes)
    # Gather all training related ops into a single one.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    increment_op = tf.assign_add(global_step, 1)
    ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step)
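    # Passing num_updates ramps the effective EMA decay up gradually early in training.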
    all_ops = tf.group(*([train_opt] + update_ops + IO_ops + [increment_op]))
    with tf.control_dependencies([all_ops]):
        train_op = ema.apply(tf.trainable_variables())
        # train_op = tf.no_op(name='train')
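    # A single sess.run(train_op) thus advances I/O staging, applies the update, bumps global_step, and refreshes the EMA.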

    ########################
    # Setting up Summaries #
    ########################
    # Stats and summaries
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    # if hvd.rank() == 0:
    summary_writer = tf.summary.FileWriter(os.path.join(params['checkpt_dir'], str(hvd.rank())), sess.graph)
    # Add image summaries for network outputs, labels, and inputs (debug runs only)
    if params['debug']:
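        # Model tensors are NCHW; tf.summary.image expects NHWC, hence the transposes below.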
        predic_inverter = tf.transpose(n_net.model_output['inverter'], perm=[0, 2, 3, 1])
        tf.summary.image("output_inverter", predic_inverter, max_outputs=2)
        predic_decoder_RE = tf.transpose(n_net.model_output['decoder_RE'], perm=[0, 2, 3, 1])
        predic_decoder_IM = tf.transpose(n_net.model_output['decoder_IM'], perm=[0, 2, 3, 1])
        tf.summary.image("output_decoder_RE", predic_decoder_RE, max_outputs=2)
        tf.summary.image("output_decoder_IM", predic_decoder_IM, max_outputs=2)
        new_labels = tf.unstack(labels, axis=1)
        for label, tag in zip(new_labels, ['potential', 'probe_RE', 'probe_IM']):
            label = tf.expand_dims(label, axis=-1)
            # label = tf.transpose(label, perm=[0,2,3,1])
            tf.summary.image(tag, label, max_outputs=2)
        tf.summary.image("inputs", tf.transpose(tf.reduce_mean(images, axis=1, keepdims=True),
                                                perm=[0, 2, 3, 1]), max_outputs=4)
    summary_merged = tf.summary.merge_all()

    ###############################
    # Setting up training session #
    ###############################
    # Initialize variables
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # Sync
    print_rank('Syncing horovod ranks...')
    sync_op = hvd.broadcast_global_variables(0)
    sess.run(sync_op)
    # prefill pipeline first
    print_rank('Prefilling I/O pipeline...')
    for i in range(len(IO_ops)):
        sess.run(IO_ops[:i + 1])
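    # Runs op 1 alone, then ops 1+2, so every staging area holds a batch before the first training step.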
# Saver and Checkpoint restore
checkpoint_file = os.path.join(params[ 'checkpt_dir' ], 'model.ckpt')
saver = tf.train.Saver(max_to_keep=None, save_relative_paths=True)
# Check if training is a restart from checkpoint
if params['restart'] and ckpt is not None:
saver.restore(sess, ckpt.model_checkpoint_path)
print_rank("Restoring from previous checkpoint @ step=%d" % last_step)
# Train
train_elf = TrainHelper_YNet(params, saver, summary_writer, n_net.get_ops(), last_step=last_step, log_freq=params['log_frequency'])
    saveStep = params['save_step']
    validateStep = params['validate_step']
    summaryStep = params['summary_step']
    train_elf.run_summary()
    maxSteps = params['max_steps']
    logFreq = params['log_frequency']
    traceStep = params['trace_step']
    maxTime = params.get('max_time', 1e12)
    val_results = []
    loss_results = []
    loss_value = 1e10
    val = 1e10
    while train_elf.last_step < maxSteps:
        train_elf.before_run()
        doLog = bool(train_elf.last_step % logFreq == 0)
        doSave = bool(train_elf.last_step % saveStep == 0)
        doSumm = bool(train_elf.last_step % summaryStep == 0 and params['debug'])
        doTrace = bool(train_elf.last_step == traceStep and params['gpu_trace'])
        doValidate = bool(train_elf.last_step % validateStep == 0)
        doFinish = bool(train_elf.start_time - params['start_time'] > maxTime)
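        # One sess.run per step; the branches below only vary which extra tensors (summaries, losses, lr) are fetched alongside train_op.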
        if train_elf.last_step == 1 and params['debug']:
            summary = sess.run([train_op, summary_merged])[-1]
            train_elf.write_summaries(summary)
        elif not doLog and not doSave and not doTrace and not doSumm:
            sess.run(train_op)
        elif doLog and not doSave and not doSumm:
            _, lr, loss_value, aux_losses = sess.run([train_op, learning_rate, total_loss, indv_losses])
            loss_results.append((train_elf.last_step, loss_value))
            train_elf.log_stats(loss_value, aux_losses, lr)
        elif doLog and doSumm and doSave:
            _, summary, loss_value, aux_losses, lr = sess.run([train_op, summary_merged, total_loss,
                                                               indv_losses, learning_rate])
            loss_results.append((train_elf.last_step, loss_value))
            train_elf.log_stats(loss_value, aux_losses, lr)
            train_elf.write_summaries(summary)
            if hvd.rank() == 0:
                saver.save(sess, checkpoint_file, global_step=train_elf.last_step)
                print_rank('Saved Checkpoint.')
        elif doLog and doSumm:
            _, summary, loss_value, aux_losses, lr = sess.run([train_op, summary_merged, total_loss,
                                                               indv_losses, learning_rate])
            loss_results.append((train_elf.last_step, loss_value))
            train_elf.log_stats(loss_value, aux_losses, lr)
            train_elf.write_summaries(summary)
        elif doSumm:
            summary = sess.run([train_op, summary_merged])[-1]
            train_elf.write_summaries(summary)
        elif doSave:
            if hvd.rank() == 0:
                saver.save(sess, checkpoint_file, global_step=train_elf.last_step)
                print_rank('Saved Checkpoint.')
        elif doTrace:
            sess.run(train_op, options=run_options, run_metadata=run_metadata)
            train_elf.save_trace(run_metadata, params['trace_dir'], params['trace_step'])
            train_elf.before_run()
        # Here we do validation:
        if doValidate:
            val = validate(network_config, hyper_params, params, sess, dset, num_batches=50)
            val_results.append((train_elf.last_step, val))
        if doFinish:
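            # Wall-clock budget (max_time) exhausted: tear down the graph and return what was collected so far.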
            # val = validate(network_config, hyper_params, params, sess, dset, num_batches=50)
            # val_results.append((train_elf.last_step, val))
            tf.reset_default_graph()
            tf.keras.backend.clear_session()
            sess.close()
            return val_results, loss_results
        if np.isnan(loss_value):
            break
    val_results.append((train_elf.last_step, val))
    tf.reset_default_graph()
    tf.keras.backend.clear_session()
    sess.close()
    return val_results, loss_results

def validate(network_config, hyper_params, params, sess, dset, num_batches=10):
    """
    Runs validation with current weights
......