I'm trying to implement a Conditional VAE for a regression problem. My dataset is composed of images, each paired with a single continuous value. I'm starting from this example CVAE on the MNIST dataset, which is built for a classification problem, so what changes do I have to make in order to deal with continuous values?
I think that if I simply concatenate the image (my data) with the single regression value and feed the result to the encoder and the decoder, I would not be treating the problem properly. Or am I wrong?
My final goal is to give a regression value to the model and generate an image.
import warnings
import numpy as np
from keras.layers import Input, Dense, Lambda
from keras.layers.merge import concatenate as concat
from keras.models import Model
from keras import backend as K
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.callbacks import EarlyStopping
from keras.optimizers import Adam
from scipy.misc import imsave
import matplotlib.pyplot as plt
warnings.filterwarnings('ignore')
%pylab inline
# Load MNIST; the raw labels stay available as Y_train / Y_test.
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

# Flatten every image into one row vector of n_pixels values.
n_pixels = np.prod(X_train.shape[1:])
X_train = X_train.reshape((len(X_train), n_pixels))
X_test = X_test.reshape((len(X_test), n_pixels))

# Convert to float32 and scale by 1/255
# (presumably maps 0-255 pixel values into [0, 1] -- confirm dtype of the raw data).
X_train = X_train.astype('float32') / 255.
X_test = X_test.astype('float32') / 255.

# Categorical (one-hot) encoding of the labels; used as the conditioning input.
y_train = to_categorical(Y_train)
y_test = to_categorical(Y_test)
# Model / training hyper-parameters.
m = 50 # batch size (passed as `batch_size` when fitting)
n_z = 2 # latent space size
encoder_dim1 = 512 # dim of encoder hidden layer
decoder_dim = 512 # dim of decoder hidden layer
decoder_out_dim = 784 # dim of decoder output layer
activ = 'relu'
optim = Adam(lr=0.001)
n_x = X_train.shape[1] # width of one flattened input image
n_y = y_train.shape[1] # width of one encoded label vector
n_epoch = 50
# Encoder inputs: one flattened image (n_x values) and its label vector (n_y values).
X = Input(shape=(n_x,))
label = Input(shape=(n_y,))
# Conditioning: image and label are concatenated and fed to the encoder together.
inputs = concat([X, label])
encoder_h = Dense(encoder_dim1, activation=activ)(inputs)
# Two linear heads of size n_z on top of the shared hidden layer.
# `mu` is used as the latent mean; `l_sigma` is exponentiated elsewhere in this
# file (exp(l_sigma / 2) when sampling), i.e. it is treated as a log-variance.
mu = Dense(n_z, activation='linear')(encoder_h)
l_sigma = Dense(n_z, activation='linear')(encoder_h)
def sample_z(args):
    """Draw a latent sample with the reparameterisation trick.

    Args:
        args: a pair ``(mu, l_sigma)`` of tensors with matching shapes;
            ``l_sigma`` is treated as a log-variance (the standard deviation
            is ``exp(l_sigma / 2)``).

    Returns:
        ``mu + exp(l_sigma / 2) * eps`` where ``eps`` is standard normal noise
        with the same shape as ``mu``.
    """
    mu, l_sigma = args
    # Take the batch size from the incoming tensor at run time rather than
    # from a global constant, so any batch size works (training, a smaller
    # final batch, or single-sample prediction).
    batch = K.shape(mu)[0]
    dim = K.int_shape(mu)[1]
    eps = K.random_normal(shape=(batch, dim), mean=0., stddev=1.)
    return mu + K.exp(l_sigma / 2) * eps
# Sampling latent space: a single Lambda layer wraps sample_z. Building it
# twice would leave the first sampling layer orphaned in the graph.
z = Lambda(sample_z, output_shape=(n_z,))([mu, l_sigma])
# merge latent space with label
# (the decoder is conditioned the same way as the encoder: by concatenation)
zc = concat([z, label])
# The decoder layers are instantiated once and kept in variables so the very
# same layer objects (and weights) can be applied to another input later.
decoder_hidden = Dense(decoder_dim, activation=activ)
decoder_out = Dense(decoder_out_dim, activation='sigmoid')
# Training path: sampled z + label -> hidden layer -> sigmoid output of size decoder_out_dim.
h_p = decoder_hidden(zc)
outputs = decoder_out(h_p)
def vae_loss(y_true, y_pred):
    """Total CVAE loss: reconstruction term plus KL term.

    The reconstruction term is the binary cross-entropy between ``y_true`` and
    ``y_pred`` summed over the last axis. The KL term is computed from the
    module-level ``mu`` and ``l_sigma`` tensors, with ``l_sigma`` treated as a
    log-variance, and is also summed over the last axis.
    """
    reconstruction_term = K.sum(K.binary_crossentropy(y_true, y_pred), axis=-1)
    kl_per_dim = K.exp(l_sigma) + K.square(mu) - 1. - l_sigma
    kl_term = 0.5 * K.sum(kl_per_dim, axis=-1)
    return reconstruction_term + kl_term
def KL_loss(y_true, y_pred):
    """KL term only, for use as a metric.

    Both arguments are ignored (they exist to satisfy the Keras metric
    signature); the value is computed from the module-level ``mu`` and
    ``l_sigma`` tensors and summed over axis 1.
    """
    kl_per_dim = K.exp(l_sigma) + K.square(mu) - 1. - l_sigma
    return 0.5 * K.sum(kl_per_dim, axis=1)
def recon_loss(y_true, y_pred):
    """Reconstruction term only, for use as a metric.

    Returns the binary cross-entropy between ``y_true`` and ``y_pred`` summed
    over the last axis.
    """
    per_element = K.binary_crossentropy(y_true, y_pred)
    return K.sum(per_element, axis=-1)
# Full CVAE: (image, label) -> reconstruction. This is the model that gets trained.
cvae = Model([X, label], outputs)
# Encoder-only model: (image, label) -> mu (no sampling involved).
encoder = Model([X, label], mu)
# Stand-alone decoder: a fresh input of width n_z + n_y (latent code followed by
# label, matching the concat order used for `zc`) is passed through the SAME
# decoder_hidden / decoder_out layer objects, so it shares the CVAE's weights.
d_in = Input(shape=(n_z+n_y,))
d_h = decoder_hidden(d_in)
d_out = decoder_out(d_h)
decoder = Model(d_in, d_out)
# Compile with the combined loss; its two components are reported as metrics.
cvae.compile(
    loss=vae_loss,
    metrics=[KL_loss, recon_loss],
    optimizer=optim,
)

# Fit: the image is both input and reconstruction target. Training stops
# early through the EarlyStopping callback (patience of 5).
cvae_hist = cvae.fit(
    [X_train, y_train],
    X_train,
    batch_size=m,
    epochs=n_epoch,
    verbose=1,
    validation_data=([X_test, y_test], X_test),
    callbacks=[EarlyStopping(patience=5)],
)
# Show the first training image and print its raw label.
plt.imshow(X_train[0].reshape(28, 28), cmap=plt.cm.gray)
# Call axis() through plt explicitly instead of relying on the bare name that
# only exists when `%pylab` has injected pylab into the namespace.
plt.axis('off')
plt.show()
print(Y_train[0])
# Encode the first training sample. Slicing with [:1] keeps the leading batch
# axis, so no hard-coded (1, 784) / (1, 10) reshape is needed and the code
# keeps working if the input or label width changes.
encoded_X0 = encoder.predict([X_train[:1], y_train[:1]])
print(encoded_X0)

# Encode the whole training set (the encoder model outputs mu) and plot the
# first two latent coordinates, coloured by the raw label.
z_train = encoder.predict([X_train, y_train])
encodings = np.asarray(z_train)
encodings = encodings.reshape(X_train.shape[0], n_z)
plt.figure(figsize=(7, 7))
plt.scatter(encodings[:, 0], encodings[:, 1], c=Y_train, cmap=plt.cm.jet)
plt.show()
def construct_numvec(digit, z=None, *, latent_dim=None, num_classes=None):
    """Build one decoder input row: latent code followed by a one-hot label.

    Args:
        digit: class index to encode; must satisfy ``0 <= digit < num_classes``.
        z: optional 1-D sequence of latent values written to the first
            ``len(z)`` slots. When omitted the latent part stays all zeros.
        latent_dim: width of the latent part; defaults to the module-level
            ``n_z``.
        num_classes: width of the one-hot part; defaults to the module-level
            ``n_y``.

    Returns:
        A ``(1, latent_dim + num_classes)`` float array.

    Raises:
        ValueError: if ``digit`` is out of range or ``z`` has more entries
            than ``latent_dim`` (either would silently corrupt the vector).
    """
    if latent_dim is None:
        latent_dim = n_z
    if num_classes is None:
        num_classes = n_y

    if not 0 <= digit < num_classes:
        raise ValueError(
            "digit must be in [0, %d), got %r" % (num_classes, digit)
        )

    out = np.zeros((1, latent_dim + num_classes))
    out[0, latent_dim + digit] = 1.

    if z is not None:
        if len(z) > latent_dim:
            # Writing past latent_dim would overwrite the one-hot label.
            raise ValueError(
                "z has %d values but the latent space has only %d"
                % (len(z), latent_dim)
            )
        out[0, :len(z)] = z
    return out
# Decode digit 3 with the latent part left at zero and display the result.
sample_3 = construct_numvec(3)
print(sample_3)
plt.figure(figsize=(3, 3))
plt.imshow(decoder.predict(sample_3).reshape(28, 28), cmap=plt.cm.gray)
# Explicit plt.axis: the bare `axis` name only exists under `%pylab`.
plt.axis('off')
plt.show()
def _plot_latent_grid(digit, sides, max_z):
    """Decode a sides x sides grid of 2-D latent points for one digit and plot it.

    Both latent coordinates sweep ``sides`` evenly spaced values from
    ``-max_z`` to ``+max_z``; the first coordinate varies along rows, the
    second along columns. Each decoded output is shown as a 28x28 image.

    Args:
        digit: class index passed to ``construct_numvec``.
        sides: number of grid points per latent axis.
        max_z: absolute bound of the sweep on each latent axis.
    """
    # linspace yields the same evenly spaced sweep as scaling i / (sides - 1),
    # without dividing by zero when sides == 1.
    grid = np.linspace(-max_z, max_z, sides)

    # Build every decoder input first so the decoder runs once on the whole
    # batch instead of once per grid cell.
    vecs = np.vstack(
        [construct_numvec(digit, [z1, z2]) for z1 in grid for z2 in grid]
    )
    decoded = decoder.predict(vecs)

    # plt.subplot / plt.axis are called explicitly rather than through the
    # bare names that only exist under `%pylab`.
    for cell, image in enumerate(decoded, start=1):
        plt.subplot(sides, sides, cell)
        plt.imshow(image.reshape(28, 28), cmap=plt.cm.gray)
        plt.axis('off')
    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=.2)
    plt.show()


sides = 8
max_z = 1.5

# Latent-space sweep for digit 3.
dig = 3
_plot_latent_grid(dig, sides, max_z)

# Latent-space sweep for digit 2.
dig = 2
_plot_latent_grid(dig, sides, max_z)