Loading stemdl/losses.py +18 −8 Original line number Diff line number Diff line Loading @@ -83,13 +83,13 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None pot_labels, probe_labels_re, probe_labels_im = [tf.expand_dims(itm, axis=1) for itm in tf.unstack(labels, axis=1)] #weight= np.prod(pot_labels.shape.as_list()[-2:]) weight=None inv_str = hyper_params.get('inv_strength', 0.1) reg_str = hyper_params.get('reg_strength', 0.1) inv_str = hyper_params.get('inv_strength', 1) reg_str = hyper_params.get('reg_strength', 0.01) dec_str = hyper_params.get('dec_strength', 1) inverter_loss = inv_str * calculate_loss_regressor(pot, pot_labels, params, hyper_params, weight=weight) decoder_loss_im = dec_str * calculate_loss_regressor(probe_im, probe_labels_im, params, hyper_params, weight=weight) decoder_loss_re = dec_str * calculate_loss_regressor(probe_re, probe_labels_re, params, hyper_params, weight=weight) psi_out_mod = thin_object(probe_re, probe_im, pot) psi_out_mod = thin_object(probe_re, probe_im, pot, summarize=False) reg_loss = reg_str * calculate_loss_regressor(psi_out_mod, tf.reduce_mean(images, axis=[1], keepdims=True), params, hyper_params, weight=weight) tf.summary.scalar('reg_loss ', reg_loss) Loading @@ -111,7 +111,6 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None #Assemble all of the losses. losses = tf.get_collection(tf.GraphKeys.LOSSES) if hyper_params['network_type'] == 'YNet': reg_str = hyper_params.get('reg_strength', 0.1) losses = [inverter_loss , decoder_loss_re, decoder_loss_im, reg_loss] # losses, prefac = ynet_adjusted_losses(losses, step) regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) Loading @@ -125,6 +124,16 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None return total_loss, loss_averages_op, losses return total_loss, loss_averages_op def get_YNet_constraint(n_net, hyper_params, params, psi_out_true, weight=1): probe_im = tf.cast(n_net.model_output['decoder_IM'], tf.float32) probe_re = tf.cast(n_net.model_output['decoder_RE'], tf.float32) pot = tf.cast(n_net.model_output['inverter'], tf.float32) psi_out_mod = thin_object(probe_re, probe_im, pot) reg_loss = calculate_loss_regressor(psi_out_mod, tf.reduce_mean(psi_out_true, axis=[1], keepdims=True), params, hyper_params, weight=weight) reg_loss = tf.cast(reg_loss, tf.float32) return reg_loss def fully_connected(n_net, layer_params, batch_size, wd=0, name=None, reuse=None): input = tf.cast(tf.reshape(n_net.model_output,[batch_size, -1]), tf.float32) dim_input = input.shape[1].value Loading Loading @@ -198,7 +207,7 @@ def calculate_loss_regressor(net_output, labels, params, hyper_params, weight=No reduction=tf.losses.Reduction.MEAN) if loss_params['type'] == 'MSE': cost = tf.losses.mean_squared_error(labels, weights=weight, predictions=net_output, reduction=tf.losses.Reduction.SUM) reduction=tf.losses.Reduction.MEAN) if loss_params['type'] == 'ABS_DIFF': cost = tf.losses.absolute_difference(labels, weights=weight, predictions=net_output, reduction=tf.losses.Reduction.MEAN) Loading Loading @@ -251,7 +260,7 @@ def fftshift(tensor, tens_format='NCHW'): shift_tensor = manip_ops.roll(tensor, shift, dims) return shift_tensor def thin_object(psi_k_re, psi_k_im, potential): def thin_object(psi_k_re, psi_k_im, potential, summarize=True): # mask = np.zeros(psi_k_re.shape.as_list(), dtype=np.float32) # ratio = 0 # if ratio == 0: Loading @@ -270,6 +279,7 @@ def thin_object(psi_k_re, psi_k_im, potential): psi_out = tf.fft2d(psi_x_stack * pot_frac / np.prod(psi_x.shape.as_list())) psi_out_mod = tf.cast(tf.abs(psi_out), tf.float32) ** 2 psi_out_mod = tf.reduce_mean(psi_out_mod, axis=1, keep_dims=True) if summarize: tf.summary.image('Psi_k_out', tf.transpose(tf.abs(psi_out_mod)**0.25, perm=[0,2,3,1]), max_outputs=1) tf.summary.image('Psi_x_in', tf.transpose(tf.abs(psi_x)**0.25, perm=[0,2,3,1]), max_outputs=1) return psi_out_mod stemdl/network.py +3 −1 Original line number Diff line number Diff line Loading @@ -708,8 +708,10 @@ class ConvNet: return tf.nn.tanh(input, name=name) elif params['activation'] == 'leaky_relu': return tf.nn.leaky_relu(input, name=name) else: elif params['activation'] == 'relu': return tf.nn.relu(input, name=name) elif params['activation'] == 'none': return input else: return input Loading stemdl/optimizers.py +1 −1 Original line number Diff line number Diff line Loading @@ -312,7 +312,7 @@ def optimize_loss(loss, # Compute gradients. grads_and_vars = opt.compute_gradients( loss, colocate_gradients_with_ops=True, var_list=var_list loss, var_list=var_list ) if dtype == 'mixed' or dtype == tf.float16: Loading stemdl/runtime.py +52 −9 Original line number Diff line number Diff line Loading @@ -13,6 +13,7 @@ import numpy as np import math from itertools import chain from multiprocessing import cpu_count from copy import deepcopy #TF import tensorflow as tf Loading Loading @@ -543,7 +544,13 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None): # Build it and propagate images through it. n_net.build_model() # # Stop gradients # stop_op = tf.stop_gradient(n_net.model_output['encoder']) # calculate the total loss # psi_out_true = tf.placeholder(tf.float32, shape=images.shape.as_list(), name="psi_out_true") psi_out_true = images constr_loss = losses.get_YNet_constraint(n_net, hyper_params, params, psi_out_true, weight=10) total_loss, _, indv_losses = losses.calc_loss(n_net, scope, hyper_params, params, labels, step=global_step, images=images, summary=summary) #get summaries, except for the one produced by string_input_producer Loading Loading @@ -573,9 +580,21 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None): loss_scaling=hyper_params.get('loss_scaling',1.0), skip_update_cond=skip_update_cond, on_horovod=True, model_scopes=n_net.scopes) # optimizer for regularization step # var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' not in str(itm.name)] var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' not in str(itm.name)] # var_list = None # print_rank(var_list) opt = tf.train.MomentumOptimizer(1e-5, 0.9) reg_opt = opt.minimize(constr_loss, var_list=var_list) # Gather all training related ops into a single one. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # with tf.control_dependencies([tf.group(*[stop_op, reg_opt, update_ops])]): with tf.control_dependencies([tf.group(*[reg_opt, update_ops])]): reg_train = tf.no_op(name='reg_train') # Gather all training related ops into a single one. increment_op = tf.assign_add(global_step, 1) ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step) all_ops = tf.group(*([train_opt] + update_ops + IO_ops + [increment_op])) Loading Loading @@ -654,7 +673,10 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None): loss_results = [] loss_value = 1e10 val = 1e10 current_batch = np.zeros(images.shape.as_list(), dtype=np.float32) batch_buffer = [] while train_elf.last_step < maxSteps : # batch_buffer.append(images.eval(session=sess)) train_elf.before_run() doLog = bool(train_elf.last_step % logFreq == 0) doSave = bool(train_elf.last_step % saveStep == 0) Loading @@ -663,17 +685,17 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None): doValidate = bool(train_elf.last_step % validateStep == 0) doFinish = bool(train_elf.start_time - params['start_time'] > maxTime) if train_elf.last_step == 1 and params['debug']: summary = sess.run([train_op, summary_merged])[-1] _, summary, current_batch = sess.run([train_op, summary_merged, images]) train_elf.write_summaries( summary ) elif not doLog and not doSave and not doTrace and not doSumm: sess.run(train_op) _, current_batch = sess.run([train_op, images]) elif doLog and not doSave and not doSumm: _, lr, loss_value, aux_losses = sess.run( [ train_op, learning_rate, total_loss, indv_losses]) _, lr, loss_value, aux_losses, current_batch = sess.run( [ train_op, learning_rate, total_loss, indv_losses, images]) loss_results.append((train_elf.last_step, loss_value)) train_elf.log_stats( loss_value, aux_losses, lr) elif doLog and doSumm and doSave : _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate ]) _, summary, loss_value, aux_losses, lr, current_batch = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate, images ]) loss_results.append((train_elf.last_step, loss_value)) train_elf.log_stats( loss_value, aux_losses, lr ) train_elf.write_summaries( summary ) Loading @@ -681,12 +703,12 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None): saver.save(sess, checkpoint_file, global_step=train_elf.last_step) print_rank('Saved Checkpoint.') elif doLog and doSumm : _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate ]) _, summary, loss_value, aux_losses, lr, current_batch = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate, images ]) loss_results.append((train_elf.last_step, loss_value)) train_elf.log_stats( loss_value, aux_losses, lr ) train_elf.write_summaries( summary ) elif doSumm: summary = sess.run([train_op, summary_merged])[-1] _, summary, current_batch = sess.run([train_op, summary_merged, images]) train_elf.write_summaries( summary ) elif doSave : if hvd.rank( ) == 0 : Loading @@ -709,6 +731,27 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None): return val_results, loss_results if np.isnan(loss_value): break # if doLog: # constr_val = sess.run(constr_loss, feed_dict={psi_out_true:current_batch}) # print_rank('\t\tstep={}, current constr_loss={:2.3e}'.format(train_elf.last_step, constr_val)) # current_batch_list = [] batch_buffer.append(current_batch) # print_rank(len(batch_buffer)) if bool(train_elf.last_step % 10 == 0 and train_elf.last_step >= 10): for itr, current_batch in enumerate(batch_buffer): # noise = np.random.random(images.shape.as_list()[1:]) # noise = noise.astype(np.float32) # mix = 0.25 # current_batch = (1 - mix) * current_batch + mix * noise[np.newaxis] _, constr_val = sess.run([reg_train, constr_loss], feed_dict={psi_out_true:current_batch}) # if doLog: print_rank('\t\tstep={}, reg iter={}, constr_loss={:2.3e}'.format(train_elf.last_step, itr, constr_val)) # print('\t\trank={}, step={}, reg iter={}, constr_loss={:2.3e}'.format(hvd.rank(), train_elf.last_step, itr, constr_val)) del batch_buffer batch_buffer = [] # for i in range(len(IO_ops)): # sess.run(IO_ops[:i + 1]) val_results.append((train_elf.last_step,val)) tf.reset_default_graph() tf.keras.backend.clear_session() Loading Loading
stemdl/losses.py +18 −8 Original line number Diff line number Diff line Loading @@ -83,13 +83,13 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None pot_labels, probe_labels_re, probe_labels_im = [tf.expand_dims(itm, axis=1) for itm in tf.unstack(labels, axis=1)] #weight= np.prod(pot_labels.shape.as_list()[-2:]) weight=None inv_str = hyper_params.get('inv_strength', 0.1) reg_str = hyper_params.get('reg_strength', 0.1) inv_str = hyper_params.get('inv_strength', 1) reg_str = hyper_params.get('reg_strength', 0.01) dec_str = hyper_params.get('dec_strength', 1) inverter_loss = inv_str * calculate_loss_regressor(pot, pot_labels, params, hyper_params, weight=weight) decoder_loss_im = dec_str * calculate_loss_regressor(probe_im, probe_labels_im, params, hyper_params, weight=weight) decoder_loss_re = dec_str * calculate_loss_regressor(probe_re, probe_labels_re, params, hyper_params, weight=weight) psi_out_mod = thin_object(probe_re, probe_im, pot) psi_out_mod = thin_object(probe_re, probe_im, pot, summarize=False) reg_loss = reg_str * calculate_loss_regressor(psi_out_mod, tf.reduce_mean(images, axis=[1], keepdims=True), params, hyper_params, weight=weight) tf.summary.scalar('reg_loss ', reg_loss) Loading @@ -111,7 +111,6 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None #Assemble all of the losses. losses = tf.get_collection(tf.GraphKeys.LOSSES) if hyper_params['network_type'] == 'YNet': reg_str = hyper_params.get('reg_strength', 0.1) losses = [inverter_loss , decoder_loss_re, decoder_loss_im, reg_loss] # losses, prefac = ynet_adjusted_losses(losses, step) regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) Loading @@ -125,6 +124,16 @@ def calc_loss(n_net, scope, hyper_params, params, labels, step=None, images=None return total_loss, loss_averages_op, losses return total_loss, loss_averages_op def get_YNet_constraint(n_net, hyper_params, params, psi_out_true, weight=1): probe_im = tf.cast(n_net.model_output['decoder_IM'], tf.float32) probe_re = tf.cast(n_net.model_output['decoder_RE'], tf.float32) pot = tf.cast(n_net.model_output['inverter'], tf.float32) psi_out_mod = thin_object(probe_re, probe_im, pot) reg_loss = calculate_loss_regressor(psi_out_mod, tf.reduce_mean(psi_out_true, axis=[1], keepdims=True), params, hyper_params, weight=weight) reg_loss = tf.cast(reg_loss, tf.float32) return reg_loss def fully_connected(n_net, layer_params, batch_size, wd=0, name=None, reuse=None): input = tf.cast(tf.reshape(n_net.model_output,[batch_size, -1]), tf.float32) dim_input = input.shape[1].value Loading Loading @@ -198,7 +207,7 @@ def calculate_loss_regressor(net_output, labels, params, hyper_params, weight=No reduction=tf.losses.Reduction.MEAN) if loss_params['type'] == 'MSE': cost = tf.losses.mean_squared_error(labels, weights=weight, predictions=net_output, reduction=tf.losses.Reduction.SUM) reduction=tf.losses.Reduction.MEAN) if loss_params['type'] == 'ABS_DIFF': cost = tf.losses.absolute_difference(labels, weights=weight, predictions=net_output, reduction=tf.losses.Reduction.MEAN) Loading Loading @@ -251,7 +260,7 @@ def fftshift(tensor, tens_format='NCHW'): shift_tensor = manip_ops.roll(tensor, shift, dims) return shift_tensor def thin_object(psi_k_re, psi_k_im, potential): def thin_object(psi_k_re, psi_k_im, potential, summarize=True): # mask = np.zeros(psi_k_re.shape.as_list(), dtype=np.float32) # ratio = 0 # if ratio == 0: Loading @@ -270,6 +279,7 @@ def thin_object(psi_k_re, psi_k_im, potential): psi_out = tf.fft2d(psi_x_stack * pot_frac / np.prod(psi_x.shape.as_list())) psi_out_mod = tf.cast(tf.abs(psi_out), tf.float32) ** 2 psi_out_mod = tf.reduce_mean(psi_out_mod, axis=1, keep_dims=True) if summarize: tf.summary.image('Psi_k_out', tf.transpose(tf.abs(psi_out_mod)**0.25, perm=[0,2,3,1]), max_outputs=1) tf.summary.image('Psi_x_in', tf.transpose(tf.abs(psi_x)**0.25, perm=[0,2,3,1]), max_outputs=1) return psi_out_mod
stemdl/network.py +3 −1 Original line number Diff line number Diff line Loading @@ -708,8 +708,10 @@ class ConvNet: return tf.nn.tanh(input, name=name) elif params['activation'] == 'leaky_relu': return tf.nn.leaky_relu(input, name=name) else: elif params['activation'] == 'relu': return tf.nn.relu(input, name=name) elif params['activation'] == 'none': return input else: return input Loading
stemdl/optimizers.py +1 −1 Original line number Diff line number Diff line Loading @@ -312,7 +312,7 @@ def optimize_loss(loss, # Compute gradients. grads_and_vars = opt.compute_gradients( loss, colocate_gradients_with_ops=True, var_list=var_list loss, var_list=var_list ) if dtype == 'mixed' or dtype == tf.float16: Loading
stemdl/runtime.py +52 −9 Original line number Diff line number Diff line Loading @@ -13,6 +13,7 @@ import numpy as np import math from itertools import chain from multiprocessing import cpu_count from copy import deepcopy #TF import tensorflow as tf Loading Loading @@ -543,7 +544,13 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None): # Build it and propagate images through it. n_net.build_model() # # Stop gradients # stop_op = tf.stop_gradient(n_net.model_output['encoder']) # calculate the total loss # psi_out_true = tf.placeholder(tf.float32, shape=images.shape.as_list(), name="psi_out_true") psi_out_true = images constr_loss = losses.get_YNet_constraint(n_net, hyper_params, params, psi_out_true, weight=10) total_loss, _, indv_losses = losses.calc_loss(n_net, scope, hyper_params, params, labels, step=global_step, images=images, summary=summary) #get summaries, except for the one produced by string_input_producer Loading Loading @@ -573,9 +580,21 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None): loss_scaling=hyper_params.get('loss_scaling',1.0), skip_update_cond=skip_update_cond, on_horovod=True, model_scopes=n_net.scopes) # optimizer for regularization step # var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' not in str(itm.name)] var_list = [itm for itm in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'CVAE' not in str(itm.name)] # var_list = None # print_rank(var_list) opt = tf.train.MomentumOptimizer(1e-5, 0.9) reg_opt = opt.minimize(constr_loss, var_list=var_list) # Gather all training related ops into a single one. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # with tf.control_dependencies([tf.group(*[stop_op, reg_opt, update_ops])]): with tf.control_dependencies([tf.group(*[reg_opt, update_ops])]): reg_train = tf.no_op(name='reg_train') # Gather all training related ops into a single one. increment_op = tf.assign_add(global_step, 1) ema = tf.train.ExponentialMovingAverage(decay=0.9, num_updates=global_step) all_ops = tf.group(*([train_opt] + update_ops + IO_ops + [increment_op])) Loading Loading @@ -654,7 +673,10 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None): loss_results = [] loss_value = 1e10 val = 1e10 current_batch = np.zeros(images.shape.as_list(), dtype=np.float32) batch_buffer = [] while train_elf.last_step < maxSteps : # batch_buffer.append(images.eval(session=sess)) train_elf.before_run() doLog = bool(train_elf.last_step % logFreq == 0) doSave = bool(train_elf.last_step % saveStep == 0) Loading @@ -663,17 +685,17 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None): doValidate = bool(train_elf.last_step % validateStep == 0) doFinish = bool(train_elf.start_time - params['start_time'] > maxTime) if train_elf.last_step == 1 and params['debug']: summary = sess.run([train_op, summary_merged])[-1] _, summary, current_batch = sess.run([train_op, summary_merged, images]) train_elf.write_summaries( summary ) elif not doLog and not doSave and not doTrace and not doSumm: sess.run(train_op) _, current_batch = sess.run([train_op, images]) elif doLog and not doSave and not doSumm: _, lr, loss_value, aux_losses = sess.run( [ train_op, learning_rate, total_loss, indv_losses]) _, lr, loss_value, aux_losses, current_batch = sess.run( [ train_op, learning_rate, total_loss, indv_losses, images]) loss_results.append((train_elf.last_step, loss_value)) train_elf.log_stats( loss_value, aux_losses, lr) elif doLog and doSumm and doSave : _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate ]) _, summary, loss_value, aux_losses, lr, current_batch = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate, images ]) loss_results.append((train_elf.last_step, loss_value)) train_elf.log_stats( loss_value, aux_losses, lr ) train_elf.write_summaries( summary ) Loading @@ -681,12 +703,12 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None): saver.save(sess, checkpoint_file, global_step=train_elf.last_step) print_rank('Saved Checkpoint.') elif doLog and doSumm : _, summary, loss_value, aux_losses, lr = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate ]) _, summary, loss_value, aux_losses, lr, current_batch = sess.run( [ train_op, summary_merged, total_loss, indv_losses, learning_rate, images ]) loss_results.append((train_elf.last_step, loss_value)) train_elf.log_stats( loss_value, aux_losses, lr ) train_elf.write_summaries( summary ) elif doSumm: summary = sess.run([train_op, summary_merged])[-1] _, summary, current_batch = sess.run([train_op, summary_merged, images]) train_elf.write_summaries( summary ) elif doSave : if hvd.rank( ) == 0 : Loading @@ -709,6 +731,27 @@ def train_YNet(network_config, hyper_params, params, gpu_id=None): return val_results, loss_results if np.isnan(loss_value): break # if doLog: # constr_val = sess.run(constr_loss, feed_dict={psi_out_true:current_batch}) # print_rank('\t\tstep={}, current constr_loss={:2.3e}'.format(train_elf.last_step, constr_val)) # current_batch_list = [] batch_buffer.append(current_batch) # print_rank(len(batch_buffer)) if bool(train_elf.last_step % 10 == 0 and train_elf.last_step >= 10): for itr, current_batch in enumerate(batch_buffer): # noise = np.random.random(images.shape.as_list()[1:]) # noise = noise.astype(np.float32) # mix = 0.25 # current_batch = (1 - mix) * current_batch + mix * noise[np.newaxis] _, constr_val = sess.run([reg_train, constr_loss], feed_dict={psi_out_true:current_batch}) # if doLog: print_rank('\t\tstep={}, reg iter={}, constr_loss={:2.3e}'.format(train_elf.last_step, itr, constr_val)) # print('\t\trank={}, step={}, reg iter={}, constr_loss={:2.3e}'.format(hvd.rank(), train_elf.last_step, itr, constr_val)) del batch_buffer batch_buffer = [] # for i in range(len(IO_ops)): # sess.run(IO_ops[:i + 1]) val_results.append((train_elf.last_step,val)) tf.reset_default_graph() tf.keras.backend.clear_session() Loading