VAE in TensorFlow
Variational Autoencoder (VAE)
The Variational Autoencoder (VAE) is a generative model that allows us to learn a probabilistic representation of data.
The VAE architecture consists of an encoder and a decoder. The encoder maps input data to a probability distribution in a latent space, while the decoder generates data from samples drawn from the latent space.
The core concept of VAE is the latent space, which is represented by the mean and variance of a Gaussian distribution. The equations for VAE are as follows:
Latent Space Mean: \( \mu = f_{\mu}(x) \)
Latent Space Variance: \( \sigma^2 = f_{\sigma^2}(x) \)
Sample from Latent Space: \( z \sim \mathcal{N}(\mu, \sigma^2) \)
The loss function for VAE includes a reconstruction loss and a regularization term to encourage the latent space to be normally distributed.
\[\begin{align} \mathcal{L} = -\mathbb{E}_{q(z|x)}[\log p(x|z)] + \text{KL}(q(z|x) || p(z)) \end{align}\]
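For reference, when the approximate posterior \( q(z|x) \) is a diagonal Gaussian \( \mathcal{N}(\mu, \sigma^2) \) and the prior \( p(z) \) is a standard normal, the KL term has a well-known closed form, which is exactly the expression implemented in the loss code later in this post:
\[\begin{align} \text{KL}(q(z|x) || p(z)) = -\frac{1}{2} \sum_{j=1}^{d} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right) \end{align}\]
where \( d \) is the dimension of the latent space.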
I'm omitting the derivation of this loss function, since there are abundant educational resources online that explain it better than I can.
The reparameterization trick allows the training of generative models with stochastic elements while maintaining differentiability. It is crucial when working with continuous latent variables. \[\begin{align} z = \mu + \sigma \cdot \epsilon \end{align}\] Here, \( \mu \) and \( \sigma \) are the mean and standard deviation of the distribution of the latent variable \( z \), and \( \epsilon \) is sampled from a fixed distribution, typically a standard Gaussian distribution \( \mathcal{N}(0,1) \).
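As a minimal standalone illustration of the trick (a NumPy sketch with made-up encoder outputs, separate from the Keras model below):

import numpy as np

# hypothetical encoder outputs for a batch of 3 samples and a 2-D latent space
mu = np.array([[0.5, -1.0], [0.0, 0.3], [1.2, 0.8]])
log_var = np.array([[-0.2, 0.1], [0.0, -0.5], [0.3, 0.2]])

# z = mu + sigma * epsilon, with sigma = exp(0.5 * log_var)
# epsilon is drawn from a standard normal, so the sampling step stays
# differentiable with respect to mu and log_var
epsilon = np.random.normal(size=mu.shape)
z = mu + np.exp(0.5 * log_var) * epsilon
print(z.shape)  # (3, 2)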
Python Jupyter Notebook Code
A well-established example of VAE's application is with MNIST digits. The following code reads MNIST data and performs some preprocessing.
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.layers import Input, Lambda, Dense
from keras.models import Model
from keras import backend as K
from keras.utils import plot_model
from keras.losses import binary_crossentropy

# network parameters
rec_dim = 784
input_shape = (rec_dim,)
int_dim = 512
lat_dim = 2

# Load the MNIST data
(x_tr, y_tr), (x_te, y_te) = mnist.load_data()

# normalize values of image pixels between 0 and 1
x_tr = x_tr.astype('float32') / 255.
x_te = x_te.astype('float32') / 255.

# 28x28 2D matrix --> 784x1 1D vector
x_tr = x_tr.reshape((len(x_tr), np.prod(x_tr.shape[1:])))
x_te = x_te.reshape((len(x_te), np.prod(x_te.shape[1:])))
print(x_tr.shape, x_te.shape)
The following code includes both the encoder and decoder. In the encoder portion, the latent factors are sampled from their mean and log-variance through the reparameterization trick.
#=======================
# Encoder
#=======================
# Z sampling function
def sampling(args):
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # Reparameterization trick:
    # draw a random sample epsilon from a Gaussian (normal) distribution;
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

# Input shape
inputs = Input(shape=input_shape)
enc_x = Dense(int_dim, activation='relu')(inputs)
z_mean = Dense(lat_dim)(enc_x)
z_log_var = Dense(lat_dim)(enc_x)

# sampling z
z_sampling = Lambda(sampling, output_shape=(lat_dim,))([z_mean, z_log_var])

# the encoder model has multiple outputs, so a list is used
encoder = Model(inputs, [z_mean, z_log_var, z_sampling])
encoder.summary()

#=======================
# Decoder
#=======================
# The input of the decoder is z
input_z = Input(shape=(lat_dim,))
dec_h = Dense(int_dim, activation='relu')(input_z)
outputs = Dense(rec_dim, activation='sigmoid')(dec_h)

# z is the input and the reconstructed image is the output
decoder = Model(input_z, outputs)
decoder.summary()
After constructing the VAE model, which encompasses both the encoder and decoder, the VAE loss, also referred to as the Evidence Lower Bound (ELBO), is calculated as the combination of the reconstruction loss and the Kullback-Leibler (KL) loss. Notably, in the case of beta-VAE, the KL loss is adjusted using a scaling factor, beta, to strike a balance between these two components.
#=======================
# VAE model
#=======================
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs)

#--------------------------------------------------
# VAE loss = ELBO
#--------------------------------------------------
# (1) Reconstruction loss (marginal likelihood): cross-entropy
rec_loss = binary_crossentropy(inputs, outputs)
rec_loss *= rec_dim

# (2) KL divergence (latent loss)
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = -0.5 * K.sum(kl_loss, axis=1)

# (3) ELBO
vae_loss = K.mean(rec_loss + kl_loss)
#--------------------------------------------------

vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
vae.summary()

history = vae.fit(x_tr, x_tr,
                  shuffle=True,
                  epochs=30,
                  batch_size=64,
                  validation_data=(x_te, x_te))
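For reference, turning this into a beta-VAE only changes how the two terms are combined: the KL term is multiplied by a weighting factor before being added to the reconstruction loss. A minimal sketch of that modification (the variable name beta and its value are illustrative assumptions, not part of the code above):

# hypothetical beta-VAE variant: scale the KL term by beta
# (beta = 1 recovers the standard VAE loss)
beta = 4.0
vae_beta_loss = K.mean(rec_loss + beta * kl_loss)
# to train the beta-VAE instead, attach this loss before compiling:
# vae.add_loss(vae_beta_loss)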
After the completion of training, we can visualize the training and validation losses across epochs.
#=================================
# Training and validation losses
#=================================
def plt_hist(hist):
    plt.plot(hist.history['loss'])
    plt.plot(hist.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper right')

plt_hist(history)

#===============================
# Raw and reconstructed images
#===============================
rec_x_te = vae.predict(x_te)

n = 10  # how many digits we will display
plt.figure(figsize=(15, 4))
for i in range(n):
    # original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_te[i].reshape(28, 28), vmin=0, vmax=1, cmap="gray")
    plt.title("Input" + str(i))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(rec_x_te[i].reshape(28, 28), vmin=0, vmax=1, cmap="gray")
    plt.title("Recon" + str(i))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
The code above displays the original test images alongside their reconstructions from the trained VAE.
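Because the decoder was built as a standalone model, it can also be used generatively: decoding latent vectors drawn from the standard normal prior produces new digit-like images. A minimal sketch of that idea (the number of samples and the plotting layout are illustrative choices, not part of the original code):

#===============================
# Generate new digits from the prior (illustrative sketch)
#===============================
n_gen = 10
# draw latent vectors from the prior p(z) = N(0, I)
z_prior = np.random.normal(size=(n_gen, lat_dim))
gen_x = decoder.predict(z_prior)

plt.figure(figsize=(15, 2))
for i in range(n_gen):
    ax = plt.subplot(1, n_gen, i + 1)
    plt.imshow(gen_x[i].reshape(28, 28), vmin=0, vmax=1, cmap="gray")
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()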
This straightforward example should improve our understanding of how latent factors are sampled and how the VAE loss, commonly referred to as the Evidence Lower Bound (ELBO), is implemented.