Commit ae82905a authored by Laanait, Nouamane

tweaks to ynet unsupervised inner loop

parent 67c84dc1
Pipeline #82327 failed in 2 minutes and 14 seconds
@@ -132,7 +132,10 @@ def get_YNet_constraint(n_net, hyper_params, params, psi_out_true, weight=1):
     reg_loss = calculate_loss_regressor(psi_out_mod, tf.reduce_mean(psi_out_true, axis=[1], keepdims=True),
                                         params, hyper_params, weight=weight)
     reg_loss = tf.cast(reg_loss, tf.float32)
-    return reg_loss
+    regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+    reg_total_loss = tf.add_n([reg_loss] + regularization, name='total_loss')
+    reg_total_loss = tf.cast(reg_total_loss, tf.float32)
+    return reg_total_loss

 def fully_connected(n_net, layer_params, batch_size, wd=0, name=None, reuse=None):
     input = tf.cast(tf.reshape(n_net.model_output, [batch_size, -1]), tf.float32)
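For readers unfamiliar with the TF 1.x idiom added above: weight-decay terms registered under tf.GraphKeys.REGULARIZATION_LOSSES are fetched with tf.get_collection and summed into the objective with tf.add_n. A minimal, self-contained sketch of that pattern (all names are illustrative, not symbols from this repo):

import tensorflow as tf  # TF 1.x

x = tf.placeholder(tf.float32, [None, 8], name='x')
w = tf.get_variable('w', [8, 1])
# register a weight-decay term on the graph-wide collection
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, 1e-4 * tf.nn.l2_loss(w))

data_loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
# gather every registered term and fold it into the objective
reg_terms = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
total_loss = tf.add_n([data_loss] + reg_terms, name='total_loss')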
@@ -548,9 +548,8 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
     # stop_op = tf.stop_gradient(n_net.model_output['encoder'])
     # calculate the total loss
     # psi_out_true = tf.placeholder(tf.float32, shape=images.shape.as_list(), name="psi_out_true")
     psi_out_true = images
-    constr_loss = losses.get_YNet_constraint(n_net, hyper_params, params, psi_out_true, weight=10)
+    constr_loss = losses.get_YNet_constraint(n_net, hyper_params, params, images, weight=10)
     total_loss, _, indv_losses = losses.calc_loss(n_net, scope, hyper_params, params, labels, step=global_step, images=images, summary=summary)
     # get summaries, except for the one produced by string_input_producer
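A note on psi_out_true = images: the inner loop further down feeds this tensor through feed_dict, which works because a TF 1.x Session can feed any feedable tensor, not only a tf.placeholder. A small sketch of the mechanism (tensor names are illustrative):

import numpy as np
import tensorflow as tf  # TF 1.x

# stands in for a tensor produced by an input pipeline (not a placeholder)
images = tf.random_uniform([4, 8], seed=0)
loss = tf.reduce_mean(tf.square(images))

with tf.Session() as sess:
    from_pipeline = sess.run(loss)  # evaluates the pipeline tensor as built
    replay = np.ones([4, 8], dtype=np.float32)
    from_buffer = sess.run(loss, feed_dict={images: replay})  # overrides it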
@@ -560,48 +559,50 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
    #######################################
    # Apply Gradients and setup train op #
    #######################################
    # get learning policy
    def learning_policy_func(step):
        return lr_policies.decay_warmup(params, hyper_params, step)
    ## TODO: implement other policies in lr_policies
    # optimizer for unsupervised step
    var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' in str(itm.name)]
    reg_hyper = deepcopy(hyper_params)
    reg_hyper['initial_learning_rate'] = 1e-1
    def learning_policy_func_reg(step):
        return lr_policies.decay_warmup(params, reg_hyper, step)
    iter_size = params.get('accumulate_step', 0)
    skip_update_cond = tf.cast(tf.floormod(global_step, tf.constant(iter_size, dtype=tf.int32)), tf.bool)
    if params['IMAGE_FP16']:
        opt_type = 'mixed'
    else:
        opt_type = tf.float32
    # setup optimizer
    opt_dict = hyper_params['optimization']['params']
    reg_opt, learning_rate = optimizers.optimize_loss(constr_loss, 'Momentum',
        {'momentum': 0.9}, learning_policy_func_reg, var_list=var_list, run_params=params, hyper_params=reg_hyper, iter_size=iter_size, dtype=opt_type,
        loss_scaling=1.0,
        skip_update_cond=skip_update_cond,
        on_horovod=True, model_scopes=None)
    # optimizer for supervised step
    def learning_policy_func(step):
        return lr_policies.decay_warmup(params, hyper_params, step)
    ## TODO: implement other policies in lr_policies
    opt_dict = hyper_params['optimization']['params']
    train_opt, learning_rate = optimizers.optimize_loss(total_loss, hyper_params['optimization']['name'],
        opt_dict, learning_policy_func, run_params=params, hyper_params=hyper_params, iter_size=iter_size, dtype=opt_type,
        loss_scaling=hyper_params.get('loss_scaling', 1.0),
        skip_update_cond=skip_update_cond,
        on_horovod=True, model_scopes=n_net.scopes)
    # optimizer for regularization step
    # var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' not in str(itm.name)]
    var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' not in str(itm.name)]
    # var_list = None
    # print_rank(var_list)
    opt = tf.train.MomentumOptimizer(1e-5, 0.9)
    reg_opt = opt.minimize(constr_loss, var_list=var_list)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # with tf.control_dependencies([tf.group(*([stop_op, reg_opt] + update_ops))]):
    # Gather unsupervised training ops
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step)
    increment_op = tf.assign_add(global_step, 1)
    with tf.control_dependencies([tf.group(*([reg_opt] + update_ops))]):
        reg_train = tf.no_op(name='reg_train')
        reg_op = ema.apply(var_list=var_list)
    # Gather all training related ops into a single one.
    # Gather supervised training related ops into a single one.
    increment_op = tf.assign_add(global_step, 1)
    ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step)
    all_ops = tf.group(*([train_opt] + update_ops + IO_ops + [increment_op]))
    with tf.control_dependencies([all_ops]):
        train_op = ema.apply(tf.trainable_variables())
        # train_op = tf.no_op(name='train')
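The block above wires two optimizers into one graph: an unsupervised step (reg_opt/reg_op) that trains only a variable subset on the constraint loss, and a supervised step (train_opt/train_op) over all variables, each chained to an exponential-moving-average update through control dependencies. A toy sketch of that wiring under TF 1.x; the losses and variable names are illustrative, and giving the two EMA trackers distinct names (the repo code reuses the default) keeps their shadow variables apart:

import tensorflow as tf  # TF 1.x

w_enc = tf.get_variable('CVAE_w', initializer=1.0)   # unsupervised subset
w_head = tf.get_variable('head_w', initializer=1.0)  # supervised-only weight
unsup_loss = tf.square(w_enc)
sup_loss = tf.square(w_enc + w_head)

global_step = tf.train.get_or_create_global_step()
increment_op = tf.assign_add(global_step, 1)
subset = [v for v in tf.trainable_variables() if 'CVAE' in v.name]

reg_min = tf.train.MomentumOptimizer(1e-5, 0.9).minimize(unsup_loss, var_list=subset)
sup_min = tf.train.MomentumOptimizer(1e-3, 0.9).minimize(sup_loss)

ema_u = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step, name='ema_u')
ema_s = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step, name='ema_s')

# running reg_op performs the optimizer update first, then refreshes the shadows
with tf.control_dependencies([reg_min]):
    reg_op = ema_u.apply(subset)
# the supervised op also advances the step counter before its EMA update
with tf.control_dependencies([tf.group(sup_min, increment_op)]):
    train_op = ema_s.apply(tf.trainable_variables())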
    ########################
    # Setting up Summaries #
@@ -668,7 +669,7 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
     logFreq = params['log_frequency']
     traceStep = params['trace_step']
     maxTime = params.get('max_time', 1e12)
-    inner_loop = params.get('inner_iter', 1e12)
+    inner_loop = hyper_params.get('inner_iter', 1e12)
     val_results = []
     loss_results = []
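Note the default: if inner_iter is absent from the config, inner_loop stays at 1e12, so the inner_loop < 100 guard in the loop below keeps batch buffering and the unsupervised replay switched off. A hypothetical config fragment that enables it:

hyper_params = {'inner_iter': 10}  # illustrative: replay the buffer every 10 steps
inner_loop = hyper_params.get('inner_iter', 1e12)
assert inner_loop < 100  # the unsupervised inner loop is now active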
@@ -732,27 +733,15 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None):
             return val_results, loss_results
         if np.isnan(loss_value):
             break
-        # if doLog:
-        #     constr_val = sess.run(constr_loss, feed_dict={psi_out_true: current_batch})
-        #     print_rank('\t\tstep={}, current constr_loss={:2.3e}'.format(train_elf.last_step, constr_val))
-        # current_batch_list = []
         if inner_loop < 100:
             batch_buffer.append(current_batch)
-            # print_rank(len(batch_buffer))
             if bool(train_elf.last_step % inner_loop == 0 and train_elf.last_step >= 10):
                 for itr, current_batch in enumerate(batch_buffer):
-                    # noise = np.random.random(images.shape.as_list()[1:])
-                    # noise = noise.astype(np.float32)
-                    # mix = 0.25
-                    # current_batch = (1 - mix) * current_batch + mix * noise[np.newaxis]
-                    _, constr_val = sess.run([reg_train, constr_loss], feed_dict={psi_out_true: current_batch})
-                    # if doLog:
-                    print_rank('\t\tstep={}, reg iter={}, constr_loss={:2.3e}'.format(train_elf.last_step, itr, constr_val))
-                    # print('\t\trank={}, step={}, reg iter={}, constr_loss={:2.3e}'.format(hvd.rank(), train_elf.last_step, itr, constr_val))
+                    _, constr_val = sess.run([reg_op, constr_loss], feed_dict={psi_out_true: current_batch})
+                    if doLog:
+                        print_rank('\t\tstep={}, reg iter={}, constr_loss={:2.3e}'.format(train_elf.last_step, itr, constr_val))
                 del batch_buffer
                 batch_buffer = []
-        # for i in range(len(IO_ops)):
-        #     sess.run(IO_ops[:i + 1])
         val_results.append((train_elf.last_step, val))
     tf.reset_default_graph()
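The rewritten inner loop above is the core of this commit: recent batches are buffered, and every inner_loop steps (once past step 10) the unsupervised op replays the whole buffer, logging the constraint loss before the buffer is cleared. A framework-free sketch of that control flow, where every name is a stand-in and run_unsupervised_step plays the role of sess.run([reg_op, constr_loss], ...):

inner_loop = 10   # illustrative; comes from hyper_params['inner_iter'] above
batch_buffer = []

def run_unsupervised_step(batch):
    # pretend constraint loss; the real code runs the reg_op training step
    return sum(batch) * 0.0

for step in range(1, 101):
    batch = [step]  # stand-in for the batch from the input pipeline
    if inner_loop < 100:
        batch_buffer.append(batch)
        # every inner_loop steps (after warm-up), replay the buffered batches
        if step % inner_loop == 0 and step >= 10:
            for itr, buffered in enumerate(batch_buffer):
                constr_val = run_unsupervised_step(buffered)
                print('step={}, reg iter={}, constr_loss={:2.3e}'.format(step, itr, constr_val))
            batch_buffer = []  # start fresh for the next window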