Commit 0bae69fa authored by Junqi Yin
Browse files

port to tf.keras

parent 98a8a4f0
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
import numpy as np
import h5py 
import keras 
import tensorflow.keras as keras

class VAEGenerator(keras.utils.Sequence):

+19 −36
Original line number Diff line number Diff line
'''
Convolutional variational autoencoder in Keras;
'''

import numpy as np;
from keras.layers import Input, Dense, Lambda, Flatten, Reshape, Dropout;
from keras.layers import Convolution3D, Conv3DTranspose;
from keras.layers import AveragePooling3D, MaxPooling3D, UpSampling3D
from keras.models import Model;
from keras.optimizers import SGD, Adam, RMSprop, Adadelta;
from keras.callbacks import Callback, ModelCheckpoint;
from keras import backend as K;
from keras import objectives;
from tensorflow.keras.layers import Input, Dense, Lambda, Flatten, Reshape, Dropout;
from tensorflow.keras.layers import Convolution3D, Conv3DTranspose;
from tensorflow.keras.models import Model;
from tensorflow.keras.optimizers import SGD, Adam, RMSprop, Adadelta;
from tensorflow.keras.callbacks import Callback, ModelCheckpoint;
from tensorflow.compat.v1.keras import backend as K;
import tensorflow.keras.losses as objectives;
import warnings;

# save history from log;        
@@ -87,9 +87,9 @@ class conv_variational_autoencoder(object):
        
        # even shaped filters may cause problems in theano backend;
        even_filters = [f for pair in filter_shapes for f in pair if f % 2 == 0];
        if K.common.image_dim_ordering() == 'th' and len(even_filters) > 0:
        if K.image_data_format() == 'th' and len(even_filters) > 0:
            warnings.warn('Even shaped filters may cause problems in Theano backend')
        if K.common.image_dim_ordering() == 'channels_first' and len(even_filters) > 0:
        if K.image_data_format() == 'channels_first' and len(even_filters) > 0:
            warnings.warn('Even shaped filters may cause problems in Theano backend')
        
        self.eps_mean = eps_mean;
@@ -97,25 +97,20 @@ class conv_variational_autoencoder(object):
        self.image_size = image_size;
        self.lr = lr; 
        # define input layer;
        if K.common.image_dim_ordering() == 'th' or K.common.image_dim_ordering() == 'channels_first':
        if K.image_data_format() == 'th' or K.image_data_format() == 'channels_first':
            self.input = Input(shape=(channels,image_size[0],image_size[1],image_size[2]))
        else:
            self.input = Input(shape=(image_size[0],image_size[1],image_size[2],channels))
                 
        # padding for multiplier of 16
        #self.input = ZeroPadding3D(1)(self.input)
            self.input = Input(shape=(image_size[0],image_size[1],image_size[1],channels))
                    
        # define convolutional encoding layers;
        self.encode_conv = [];
        layer = Convolution3D(feature_maps[0],filter_shapes[0],padding='same',
                              activation=activation,strides=strides[0])(self.input);
        layer = AveragePooling3D((2,2,2),padding='same')(layer)
        self.encode_conv.append(layer);
        for i in range(1,conv_layers):
            layer = Convolution3D(feature_maps[i],filter_shapes[i],
                                  padding='same',activation=activation,
                                  strides=strides[i])(self.encode_conv[i-1]);
            layer = AveragePooling3D((2,2,2),padding='same')(layer)
            self.encode_conv.append(layer);
        
        # define dense encoding layers;
@@ -147,7 +142,7 @@ class conv_variational_autoencoder(object):
        
        # dummy model to get image size after encoding convolutions;
        self.decode_conv = [];
        if K.common.image_dim_ordering() == 'th' or K.common.image_dim_ordering() == 'channels_first':
        if K.image_data_format() == 'th' or K.image_data_format() == 'channels_first':
            dummy_input = np.ones((1,channels,image_size[0],image_size[1],image_size[2]))
        else:
            dummy_input = np.ones((1,image_size[0],image_size[1],image_size[2],channels))
@@ -162,35 +157,23 @@ class conv_variational_autoencoder(object):
        
        # define deconvolutional decoding layers;
        for i in range(1,conv_layers):
            if K.common.image_dim_ordering() == 'th' or K.common.image_dim_ordering() == 'channels_first':
            if K.image_data_format() == 'th' or K.image_data_format() == 'channels_first':
                dummy_input = np.ones((1,channels,image_size[0],image_size[1],image_size[2]))
            else:
                dummy_input = np.ones((1,image_size[0],image_size[1],image_size[2],channels))
            dummy = Model(self.input, self.encode_conv[-i-1]);
            conv_size = list(dummy.predict(dummy_input).shape);
            
            if K.common.image_dim_ordering() == 'th' or K.common.image_dim_ordering() == 'channels_first':
            if K.image_data_format() == 'th' or K.image_data_format() == 'channels_first':
                conv_size[1] = feature_maps[-i]
            else:
                conv_size[4] = feature_maps[-i]
            
            layer = Conv3DTranspose(feature_maps[-i-1],filter_shapes[-i],
                                    padding='same',activation=activation,
                                    strides=strides[-i])
            self.all_decoding.append(layer)
            layer = layer(self.decode_conv[i-1])
            layer = UpSampling3D((2,2,2))(layer)
            self.all_decoding.append(UpSampling3D((2,2,2)))
            self.decode_conv.append(layer);
        
        layer = Conv3DTranspose(feature_maps[-1],filter_shapes[-1],
                                    padding='same',activation=activation,
                                    strides=strides[-1])
        self.all_decoding.append(layer)
        layer = layer(self.decode_conv[-1])
        layer = UpSampling3D((2,2,2))(layer)
        self.all_decoding.append(UpSampling3D((2,2,2)))
        self.decode_conv.append(layer);
                                    strides=strides[-i]);
            self.all_decoding.append(layer);
            self.decode_conv.append(layer(self.decode_conv[i-1]));
        
        layer = Conv3DTranspose(channels,filter_shapes[0],padding='same',
                                activation='sigmoid',strides=strides[0]);
+10 −9
Original line number Diff line number Diff line
import os, sys, h5py
import horovod.keras as hvd
import horovod.tensorflow.keras as hvd
import tensorflow as tf
import keras
from keras import backend as K
from tensorflow.compat.v1.keras import backend as K
import math
import numpy as np
from dataloader import VAEGenerator
@@ -45,7 +44,7 @@ def VAE(input_shape, latent_dim=3, lr=0.001):
    conv_layers = 4
    feature_maps = [64,64,64,64]
    filter_shapes = [(3,3,3),(3,3,3),(3,3,3),(3,3,3)]
    strides = [(1,1,1),(1,1,1),(1,1,1),(1,1,1)]
    strides = [(2,2,2),(2,2,2),(2,2,2),(2,2,2)]
    dense_layers = 1
    dense_neurons = [64]
    dense_dropouts = [0]
@@ -66,15 +65,15 @@ def run_vae(cm_file_train, cm_file_val,
    gen_val = VAEGenerator(cm_file_val, hvd_size=hvd.size(), batch_size=batch_size, shuffle=True)
    input_shape = gen_train.get_shape()
    
    config = tf.ConfigProto()
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    K.set_session(tf.Session(config=config))
    config.gpu_options.visible_device_list = '0' #str(hvd.local_rank())
    K.set_session(tf.compat.v1.Session(config=config))

    #epochs = int(math.ceil(epochs / hvd.size()))

    #vae = VAE(input_shape[1:], hyper_dim, lr=0.001*hvd.size()) 
    vae = VAE(input_shape[1:], hyper_dim, lr=0.00005) 
    vae = VAE(input_shape[1:], hyper_dim, lr=0.0005) 
    vae.optimizer = hvd.DistributedOptimizer(vae.optimizer)
    vae.model.compile(optimizer=vae.optimizer, loss=vae._vae_loss)

@@ -104,6 +103,8 @@ def run_vae(cm_file_train, cm_file_val,
               batch_size=batch_size, epochs=epochs, 
               initial_epoch=resume_from_epoch, callbacks=callbacks) 
    if hvd.rank() == 0:   
        vae.embedder.save("encoder", save_format='tf')
        vae.generator.save("decoder", save_format='tf')
        vae.model.save_weights(model_weight.format(epoch=epochs))
        vae.save(model_file.format(epoch=epochs))
        losses = {'loss':[], 'val_loss':[]}