Commit 10d37686 authored by Laanait, Nouamane

using layer normalization instead of batch norm for YNet

parent 81ddd819
Pipeline #80778 failed in 2 minutes and 41 seconds
......@@ -543,7 +543,8 @@ class DatasetLMDB(DatasetTFRecords):
images_newshape = [self.params['batch_size']] + self.data_specs['image_shape']
labels = tf.reshape(labels, labels_newshape)
images = tf.reshape(images, images_newshape)
labels = self.image_scaling(labels)
#labels = self.image_scaling(labels)
images = self.image_scaling(images)
# data augmentation
if self.params[self.mode + '_distort']:
with tf.device('/gpu:%i' % hvd.local_rank()):
......
......@@ -51,13 +51,13 @@ def decay_warmup(params, hyper_params, global_step):
return lr
warm_up_slope = hyper_params.get('warm_up_slope', 1.)
warm_up_slope = hyper_params.get('warm_up_slope', 0)
def constant_ramp():
lr = tf.cast(INITIAL_LEARNING_RATE, tf.float32) * (tf.cast(global_step, tf.float32) * warm_up_slope + 1)
lr = tf.math.minimum(tf.cast(WARM_UP_LEARNING_RATE_MAX, tf.float32), lr)
return lr
if hyper_params['warm_up']:
if warm_up_slope > 1e-2:
# LEARNING_RATE = tf.cond(global_step < ramp_up_steps, ramp, lambda: decay(ramp()))
#LEARNING_RATE = tf.cond(global_step < ramp_up_steps, linear_ramp, lambda: decay(linear_ramp()))
ramp_up_steps = tf.cast(WARM_UP_LEARNING_RATE_MAX/INITIAL_LEARNING_RATE, global_step.dtype)
......
......@@ -607,15 +607,15 @@ class ConvNet:
'beta': tf.constant_initializer(0.0, dtype=tf.float16),
'gamma': tf.constant_initializer(0.1, dtype=tf.float16),
}
if self.params['IMAGE_FP16']:
input= tf.cast(input, tf.float32)
# with tf.variable_scope('layer_normalization', reuse=None) as scope:
# output = tf.keras.layers.LayerNormalization(trainable=False)(inputs=input)
mean , variance = tf.nn.moments(input, axes=[2,3], keep_dims=True)
output = (input - mean)/ (tf.sqrt(variance) + 1e-7)
if self.params['IMAGE_FP16']:
output = tf.cast(output, tf.float16)
#output = tf.contrib.layers.batch_norm(input, decay=decay, scale=True, epsilon=epsilon,zero_debias_moving_mean=False,is_training=is_training,fused=True,data_format='NCHW',renorm=False,param_initializers=param_initializers)
if self.hyper_params['network_type'] == 'YNet':
if self.params['IMAGE_FP16']:
input= tf.cast(input, tf.float32)
mean , variance = tf.nn.moments(input, axes=[2,3], keep_dims=True)
output = (input - mean)/ (tf.sqrt(variance) + 1e-7)
if self.params['IMAGE_FP16']:
output = tf.cast(output, tf.float16)
else:
output = tf.contrib.layers.batch_norm(input, decay=decay, scale=True, epsilon=epsilon,zero_debias_moving_mean=False,is_training=is_training,fused=True,data_format='NCHW',renorm=False,param_initializers=param_initializers)
#output = tf.contrib.layers.batch_norm(input, decay=decay, scale=True, epsilon=epsilon,zero_debias_moving_mean=False,is_training=is_training,fused=True,data_format='NCHW',renorm=False)
# output = input
# Keep tabs on the number of weights
......@@ -1247,14 +1247,14 @@ class FCDenseNet(ConvNet):
def _freq2space(self, inputs=None):
shape = inputs.shape
weights_dim = 256
weights_dim = 256
num_fc = 2
# if weights_dim < 4096 :
fully_connected = OrderedDict({'type': 'fully_connected','weights': weights_dim,'bias': weights_dim, 'activation': 'relu',
fully_connected = OrderedDict({'type': 'fully_connected','weights': weights_dim,'bias': weights_dim, 'activation': 'leaky_relu',
'regularize': True})
deconv = OrderedDict({'type': "deconv_2D", 'stride': [2, 2], 'kernel': [3,3], 'features': 8, 'padding': 'SAME', 'upsample': 2})
deconv = OrderedDict({'type': "deconv_2D", 'stride': [2, 2], 'kernel': [4,4], 'features': 8, 'padding': 'SAME', 'upsample': 2})
conv = OrderedDict({'type': "conv_2D", 'stride': [1, 1], 'kernel': [3,3], 'features': inputs.shape.as_list()[1],
'padding': 'SAME', 'activation': 'relu', 'dropout':0.5})
'padding': 'SAME', 'activation': 'leaky_relu', 'dropout':0.5})
conv_1by1 = OrderedDict({'type': "conv_2D", 'stride': [1, 1], 'kernel': [3,3], 'features': 1, 'padding': 'SAME',
'activation': 'relu', 'dropout':0.5})
with tf.variable_scope('Freq2Space', reuse=tf.AUTO_REUSE) as _ :
......@@ -1266,8 +1266,8 @@ class FCDenseNet(ConvNet):
out = tf.reshape(out, [shape[0], -1])
for i in range(num_fc):
if i > 0:
#weights_dim = min(4096, int(shape.as_list()[-2]**2))
weights_dim = min(weights_dim, int(shape.as_list()[-2]**2))
weights_dim = min(4096, int(shape.as_list()[-2]**2))
#weights_dim = min(weights_dim, int(shape.as_list()[-2]**2))
fully_connected['weights'] = weights_dim
fully_connected['bias'] = weights_dim
with tf.variable_scope('FC_%d' %i, reuse=self.reuse) as _ :
......@@ -1281,6 +1281,7 @@ class FCDenseNet(ConvNet):
conv_1by1_n['stride'] = [1,1]
conv_1by1_n['features'] = 16
# self.print_rank("num_upsamp=", num_upsamp)
#if False:
if num_upsamp >= 0:
for up in range(num_upsamp):
with tf.variable_scope('deconv_upscale_%d' % up, reuse=self.reuse) as _:
......@@ -1288,14 +1289,15 @@ class FCDenseNet(ConvNet):
with tf.variable_scope('conv_upscale_%d' % up, reuse=self.reuse) as _:
out, _ = self._conv(input=out, params=conv_1by1_n)
out = self._batch_norm(input=out)
out = self._activate(input=out, params=conv_1by1_n)
#out = self._activate(input=out, params=conv_1by1_n)
rate = conv_1by1_n.get('dropout', 0)
out = tf.keras.layers.SpatialDropout2D(rate, data_format='channels_first')(inputs=out, training= self.operation == 'train')
# self.print_rank(" out shape after upscale+conv", out.shape)
else:
out = tf.transpose(out, perm=[0,2,3,1])
out = tf.image.resize_images(out, [shape[2], shape[3]])
out = tf.cast(tf.transpose(out, perm=[0,3,1,2]), tf.float16)
out = tf.transpose(out, perm=[0,3,1,2])
#out = tf.cast(tf.transpose(out, perm=[0,3,1,2]), tf.float16)
with tf.variable_scope('conv_restore', reuse=tf.AUTO_REUSE) as _ :
out, _ = self._conv(input=out, params=conv)
......@@ -1878,7 +1880,7 @@ class FCNet(ConvNet):
if fc_params['type'] == 'fully_connected':
if fc_params['activation'] == 'tanh':
lin_initializer.factor = 1.15
elif fc_params['activation'] == 'relu':
elif fc_params['activation'] == 'relu' or fc_params['activation'] == 'leaky_relu':
lin_initializer.factor = 1.43
elif fc_params['type'] == 'linear_output':
lin_initializer.factor = 1.0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment