...
 
Commits (3)
...@@ -548,10 +548,13 @@ class DatasetLMDB(DatasetTFRecords): ...@@ -548,10 +548,13 @@ class DatasetLMDB(DatasetTFRecords):
# data augmentation # data augmentation
if self.params[self.mode + '_distort']: if self.params[self.mode + '_distort']:
with tf.device('/gpu:%i' % hvd.local_rank()): with tf.device('/gpu:%i' % hvd.local_rank()):
images = tf.transpose(images, perm=[0,2,3,1]) if self.params.get('random_crop', False):
images = self.random_crop_resize(images) images = tf.transpose(images, perm=[0,2,3,1])
images = self.add_noise_image(images) images = self.random_crop_resize(images)
images = tf.transpose(images, perm=[0,3,1,2]) images = self.add_noise_image(images)
images = tf.transpose(images, perm=[0,3,1,2])
else:
images = self.add_noise_image(images)
return images, labels return images, labels
def get_glimpses(self, batch_images): def get_glimpses(self, batch_images):
......