TensorFlow: Variational Autoencoder (VAE) for MNIST Digits

This post demonstrates how to implement a Variational Autoencoder (VAE) in TensorFlow (Keras) code, using the well-established example of MNIST digit data.


VAE in TensorFlow




Variational Autoencoder (VAE)


The Variational Autoencoder (VAE) is a generative model that allows us to learn a probabilistic representation of data.

The VAE architecture consists of an encoder and a decoder. The encoder maps input data to a probability distribution in a latent space, while the decoder generates data from samples drawn from the latent space.


The core concept of the VAE is a probabilistic latent space: for each input \( x \), the encoder outputs the mean and variance of a Gaussian distribution over the latent variable \( z \). The key equations are as follows:

Latent Space Mean: \( \mu = f_{\mu}(x) \)
Latent Space Variance: \( \sigma^2 = f_{\sigma^2}(x) \)
Sample from Latent Space: \( z \sim \mathcal{N}(\mu, \sigma^2) \)

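The VAE loss function combines a reconstruction loss with a regularization term, the Kullback-Leibler (KL) divergence, which pushes the approximate posterior \( q(z|x) \) toward the prior \( p(z) \), usually the standard normal distribution \( \mathcal{N}(0, I) \).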

\[\begin{align} \mathcal{L} = -\mathbb{E}_{q(z|x)}[\log p(x|z)] + \text{KL}(q(z|x) || p(z)) \end{align}\]
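I'm omitting the derivation of this loss function, since abundant high-quality educational resources online explain it better than I can. In short, \( \mathcal{L} \) is the negative of the Evidence Lower Bound (ELBO), so minimizing it maximizes a lower bound on the log-likelihood \( \log p(x) \).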
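For a diagonal Gaussian posterior \( q(z|x) = \mathcal{N}(\mu, \sigma^2) \) and a standard normal prior, the KL term has a closed form, \( \text{KL} = -\tfrac{1}{2}\sum_j (1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2) \), which is the expression the Keras code computes as `kl_loss`. The NumPy sketch below is illustrative only: the helper names `kl_to_standard_normal` and `kl_monte_carlo` are hypothetical, and it assumes the encoder outputs the log-variance. The Monte Carlo estimate simply checks the closed form numerically.

```python
import numpy as np

def kl_to_standard_normal(z_mean, z_log_var):
    """Closed-form KL(N(mu, sigma^2) || N(0, I)) for each row of a batch.

    z_mean, z_log_var: arrays of shape (batch, latent_dim); the encoder is
    assumed to output log(sigma^2) rather than sigma^2 itself.
    """
    z_mean = np.asarray(z_mean, dtype=np.float64)
    z_log_var = np.asarray(z_log_var, dtype=np.float64)
    terms = 1.0 + z_log_var - np.square(z_mean) - np.exp(z_log_var)
    return -0.5 * np.sum(terms, axis=1)

def kl_monte_carlo(mu, log_var, n_samples=200_000, seed=0):
    """Monte Carlo estimate of the same KL: E_q[log q(z) - log p(z)]."""
    rng = np.random.default_rng(seed)
    sigma = np.exp(0.5 * log_var)
    z = mu + sigma * rng.standard_normal((n_samples, mu.size))
    log_q = -0.5 * (np.log(2 * np.pi) + log_var + ((z - mu) / sigma) ** 2)
    log_p = -0.5 * (np.log(2 * np.pi) + z ** 2)
    return np.mean(np.sum(log_q - log_p, axis=1))

mu = np.array([0.5, -1.0])
log_var = np.array([-0.7, 0.3])
print(kl_to_standard_normal(mu[None, :], log_var[None, :])[0])  # closed form
print(kl_monte_carlo(mu, log_var))                              # should agree closely
```

A posterior identical to the prior (\( \mu = 0 \), \( \log\sigma^2 = 0 \)) gives zero KL, and every deviation from it is penalized.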

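The reparameterization trick keeps training differentiable despite the stochastic sampling step, which is crucial when working with continuous latent variables. Instead of sampling \( z \) directly, we write it as a deterministic function of the encoder outputs and an independent noise variable:
\[\begin{align} z = \mu + \sigma \cdot \epsilon \end{align}\]
Here, \( \mu \) and \( \sigma \) are the mean and standard deviation of the distribution of the latent variable \( z \), and \( \epsilon \) is sampled from a fixed distribution, typically the standard Gaussian \( \mathcal{N}(0, 1) \). Because all of the randomness lives in \( \epsilon \), gradients of the loss can flow through \( \mu \) and \( \sigma \) back into the encoder.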
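A minimal NumPy sketch of the trick, assuming (as the Keras code later does) that the encoder outputs the log-variance, so that \( \sigma = \exp(0.5 \log\sigma^2) \). The function name `reparameterize` is hypothetical. With \( \epsilon \) held fixed, \( z \) is an ordinary differentiable function with \( \partial z / \partial \mu = 1 \) and \( \partial z / \partial \log\sigma^2 = 0.5\,\sigma\,\epsilon \); the sketch checks both against finite differences.

```python
import numpy as np

def reparameterize(mu, log_var, eps):
    """z = mu + sigma * eps, with sigma = exp(0.5 * log_var)."""
    return mu + np.exp(0.5 * log_var) * eps

rng = np.random.default_rng(42)
mu, log_var = 2.0, np.log(0.25)        # sigma = 0.5
eps = rng.standard_normal(100_000)     # noise drawn independently of mu and sigma
z = reparameterize(mu, log_var, eps)
print(z.mean(), z.std())               # approximately 2.0 and 0.5

# With eps fixed, compare analytic gradients to central finite differences.
h = 1e-6
e0 = eps[:5]
dz_dmu = (reparameterize(mu + h, log_var, e0) - reparameterize(mu - h, log_var, e0)) / (2 * h)
dz_dlv = (reparameterize(mu, log_var + h, e0) - reparameterize(mu, log_var - h, e0)) / (2 * h)
print(np.allclose(dz_dmu, 1.0))
print(np.allclose(dz_dlv, 0.5 * np.exp(0.5 * log_var) * e0))
```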



Python Jupyter Notebook Code


A well-established application of the VAE uses the MNIST digits. The following code loads the MNIST data, scales the pixel values to [0, 1], and flattens each 28x28 image into a 784-dimensional vector.

import numpy as np
import matplotlib.pyplot as plt
 
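# NOTE: this code uses the Keras 2.x API (standalone Keras 2, or tf.keras in TensorFlow 2.15 and earlier).
# Keras 3 moved the keras.backend math functions used below (K.exp, K.random_normal, ...) to keras.ops / keras.random.
from keras.datasets import mnist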
from keras.layers import Input, Lambda, Dense
from keras.models import Model
from keras import backend as K
from keras.utils import plot_model
from keras.losses import binary_crossentropy
 
# network parameters
rec_dim=784
input_shape = (rec_dim,)
int_dim = 512
lat_dim = 2
 
# Load the MNIST data
(x_tr, y_tr), (x_te, y_te) = mnist.load_data()
 
# normalize image pixel values to the range [0, 1]
x_tr = x_tr.astype('float32') / 255.
x_te = x_te.astype('float32') / 255.
 
# 28x28 2D matrix --> 784x1 1D vector
x_tr = x_tr.reshape((len(x_tr), np.prod(x_tr.shape[1:])))
x_te = x_te.reshape((len(x_te), np.prod(x_te.shape[1:])))
 
print(x_tr.shape, x_te.shape)
 

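The following code defines both the encoder and the decoder. The encoder outputs the mean and the log-variance \( \log\sigma^2 \) of the latent factors (predicting the log-variance keeps \( \sigma = \exp(0.5 \log\sigma^2) \) positive without constraining the layer), then samples \( z \) from them with the reparameterization trick.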

#=======================
# Encoder
#=======================
# Z sampling function
def sampling(args):
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    
    # Reparameterization Trick
    # draw a random sample ε from a Gaussian (normal) distribution
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
 
# Input shape
inputs = Input(shape=input_shape)
enc_x  = Dense(int_dim, activation='relu')(inputs)
 
z_mean    = Dense(lat_dim)(enc_x)
z_log_var = Dense(lat_dim)(enc_x)
 
# sampling z
z_sampling = Lambda(sampling, (lat_dim,))([z_mean, z_log_var])
 
# encoder model has multi-output so a list is used
encoder = Model(inputs,[z_mean,z_log_var,z_sampling])
encoder.summary()
 
#=======================
# Decoder
#=======================
# Input of decoder is z
input_z = Input(shape=(lat_dim,))
dec_h   = Dense(int_dim, activation='relu')(input_z)
outputs = Dense(rec_dim, activation='sigmoid')(dec_h)
 
# z is the input and the reconstructed image is the output
decoder = Model(input_z, outputs)
decoder.summary()
 
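To interpret the `summary()` output, it helps to count parameters by hand: a `Dense` layer with a bias has `in_dim * out_dim + out_dim` weights, while the `Input` and `Lambda` layers have none. A small pure-Python sketch for the sizes used here (`dense_params` is a hypothetical helper, not part of Keras); the totals should match those reported by `encoder.summary()` and `decoder.summary()`.

```python
def dense_params(in_dim, out_dim):
    """Weights plus biases of a fully connected (Dense) layer."""
    return in_dim * out_dim + out_dim

rec_dim, int_dim, lat_dim = 784, 512, 2

encoder_params = (dense_params(rec_dim, int_dim)        # 784 -> 512 hidden layer
                  + 2 * dense_params(int_dim, lat_dim))  # z_mean and z_log_var heads
decoder_params = (dense_params(lat_dim, int_dim)        # 2 -> 512 hidden layer
                  + dense_params(int_dim, rec_dim))     # 512 -> 784 sigmoid output

print(encoder_params, decoder_params, encoder_params + decoder_params)
# 403972 403728 807700
```

Almost all of the parameters sit in the two 784x512 weight matrices; the 2-D bottleneck itself is tiny.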

After chaining the encoder and decoder into the full VAE model, we compute the VAE loss, the negative of the Evidence Lower Bound (ELBO), as the sum of the reconstruction loss and the Kullback-Leibler (KL) loss. In a beta-VAE, the KL term is multiplied by a scaling factor β to balance reconstruction quality against latent regularization; the code here is the standard VAE, i.e., β = 1.
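To make the loss arithmetic concrete, here is a framework-free NumPy sketch of the per-batch objective, assuming a Bernoulli decoder (pixel values in [0, 1] and sigmoid outputs, as in this post). The names `beta_vae_loss` and `beta` are illustrative; `beta=1.0` gives the standard VAE loss. It also shows why the Keras code multiplies `binary_crossentropy` by `rec_dim`: Keras averages the cross-entropy over the 784 pixels, whereas the negative log-likelihood sums over them.

```python
import numpy as np

def beta_vae_loss(x, x_recon, z_mean, z_log_var, beta=1.0, eps=1e-7):
    """Batch mean of (reconstruction loss + beta * KL) for a Bernoulli decoder.

    x, x_recon: (batch, n_pixels) arrays with values in [0, 1].
    z_mean, z_log_var: (batch, latent_dim) encoder outputs.
    Returns the scalar loss plus the per-sample reconstruction and KL terms.
    """
    x_recon = np.clip(x_recon, eps, 1.0 - eps)  # avoid log(0)
    # Bernoulli negative log-likelihood summed over pixels; dividing by
    # n_pixels would give the per-pixel mean that Keras' binary_crossentropy returns.
    rec = -np.sum(x * np.log(x_recon) + (1.0 - x) * np.log(1.0 - x_recon), axis=1)
    # Closed-form KL to the N(0, I) prior, summed over latent dimensions
    kl = -0.5 * np.sum(1.0 + z_log_var - np.square(z_mean) - np.exp(z_log_var), axis=1)
    return np.mean(rec + beta * kl), rec, kl

# Random stand-ins for a batch of 4 flattened images and a 2-D latent space
rng = np.random.default_rng(0)
x = rng.random((4, 784))
x_recon = rng.random((4, 784))
z_mean = rng.standard_normal((4, 2))
z_log_var = rng.standard_normal((4, 2))

loss_1, rec, kl = beta_vae_loss(x, x_recon, z_mean, z_log_var)         # standard VAE
loss_4, _, _ = beta_vae_loss(x, x_recon, z_mean, z_log_var, beta=4.0)  # stronger KL weight
print(loss_4 - loss_1, 3.0 * kl.mean())  # equal: beta only rescales the KL term
```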

#=======================
# VAE model
#=======================
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs)
 
#--------------------------------------------------
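# VAE_loss = -ELBO (negative Evidence Lower Bound)
#--------------------------------------------------
# (1) Reconstruction loss (negative log-likelihood): binary cross-entropy.
#     binary_crossentropy averages over the 784 pixels, so multiply by
#     rec_dim to get the per-image sum over pixels.
rec_loss = binary_crossentropy(inputs, outputs)
rec_loss *= rec_dim
# (2) KL divergence (latent loss) between q(z|x) and the prior N(0, I)
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = -0.5 * K.sum(kl_loss, 1)
# (3) Negative ELBO, averaged over the batch
vae_loss = K.mean(rec_loss + kl_loss)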
#--------------------------------------------------
 
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
vae.summary()
 
history = vae.fit(x_tr, x_tr, shuffle=True,
                  epochs=30, batch_size=64,
                  validation_data=(x_te, x_te))
 

After the completion of training, we can visualize the training and validation losses across epochs.

#=================================
# Training and validation losses
#=================================
def plt_hist(hist):
    plt.plot(hist.history['loss'])
    plt.plot(hist.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper right')
    
plt_hist(history)
 
#===============================
# Raw and reconstructed images
#===============================
rec_x_te = vae.predict(x_te)
 
n = 10  # how many digits we will display
plt.figure(figsize=(15, 4))
for i in range(n):
    # original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_te[i].reshape(28, 28),
               vmin=0, vmax=1, cmap="gray")
    plt.title("Input"+str(i))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
 
    # reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(rec_x_te[i].reshape(28, 28),
               vmin=0, vmax=1, cmap="gray")
    plt.title("Recon"+str(i))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
 
plt.show()
 


The resulting figure compares the first ten test digits (top row) with their reconstructions (bottom row).


This straightforward example illustrates the two core ingredients of a VAE: sampling latent factors with the reparameterization trick, and implementing the VAE loss, i.e., the negative of the Evidence Lower Bound (ELBO).

