
Variational Autoencoder

Fishest edited this page Mar 26, 2017 · 19 revisions

Introduction

Briefly speaking, a Variational Autoencoder (VAE) is a neural network that encodes an image into a vector in a z-dimensional latent space. The vector is supposed to be a random sample drawn from a z-dimensional normal distribution. From the encoded vector representation, the Decoder network can then reconstruct the original image. The reason the latent space follows a z-dimensional normal distribution is that we can then draw random samples from that distribution and feed them into the Decoder network. This way we can obtain brand new images that are not even in the dataset we trained on!
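To make the generation step concrete, here is a minimal numpy sketch of the sampling side. The `decoder` mentioned in the comment is hypothetical (the trained network is not shown here); the point is only that new images come from drawing z from a standard normal prior:

```python
import numpy as np

rng = np.random.default_rng(0)
z_dim = 100                          # hypothetical latent dimensionality

# Draw a random latent vector from the z-dimensional standard normal prior.
z = rng.standard_normal((1, z_dim))

# A trained decoder (not shown here) would turn this vector into an image:
# new_image = decoder(z)
print(z.shape)  # (1, 100)
```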

If you haven't learned anything about VAEs yet, here is a good place to find an explanation. The next sections assume basic knowledge about VAEs.

Most Basic VAE network

The most basic VAE network takes the images, each flattened from an image matrix into a single vector of length dim. The input, of size [batch_size, dim], is fed through 3 fully connected layers. The last fully connected layers then produce a mean vector and a standard deviation vector in the latent space.
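As a sketch of that pipeline, here is a single-hidden-layer numpy version of the encoder plus the reparameterization step (the wiki's network has 3 layers and learned weights; the widths and weight initialization here are hypothetical stand-ins):

```python
import numpy as np

rng = np.random.default_rng(0)

dim = 784      # flattened image length (28*28 for MNIST)
hidden = 256   # hypothetical hidden width
z_dim = 100    # latent dimension

# Hypothetical random weights standing in for the trained fully connected layers.
W_h = rng.normal(0, 0.01, (dim, hidden))
W_mu = rng.normal(0, 0.01, (hidden, z_dim))
W_logvar = rng.normal(0, 0.01, (hidden, z_dim))

def encode(x):
    """Map a batch of flattened images to latent mean and log-variance."""
    h = np.tanh(x @ W_h)
    return h @ W_mu, h @ W_logvar

def reparameterize(mu, logvar):
    """Sample z = mu + sigma * eps, so gradients can flow through mu and sigma."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps

x = rng.random((8, dim))        # a fake batch of 8 flattened "images"
mu, logvar = encode(x)
z = reparameterize(mu, logvar)
print(z.shape)                  # (8, 100)
```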

MNIST dataset

The MNIST dataset has binary images, with pixels coded as 0 and 1 (black and white). Let's try to run it through the basic network with z = 100. This is supposed to be overkill, because the image dimension is 28*28 = 784, so we are merely compressing the data by a factor of roughly 8.
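The exact cost function used by the wiki's code is not shown here, so as an assumption, here is the standard VAE objective that the logged "cost" values typically correspond to: a Bernoulli reconstruction cross-entropy plus the closed-form KL divergence from the latent posterior to N(0, I), summed over the batch:

```python
import numpy as np

def vae_cost(x, x_hat, mu, logvar, eps=1e-8):
    """Standard VAE objective summed over the batch: Bernoulli
    reconstruction cross-entropy plus KL(q(z|x) || N(0, I))."""
    x_hat = np.clip(x_hat, eps, 1 - eps)          # keep log() finite
    recon = -np.sum(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat))
    kl = -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))
    return recon + kl

# Tiny toy batch: one "image" with two binary pixels.
x = np.array([[0.0, 1.0]])
x_hat = np.array([[0.1, 0.9]])                    # decoder outputs
mu = np.zeros((1, 2)); logvar = np.zeros((1, 2))  # posterior == prior, so KL = 0
print(round(vae_cost(x, x_hat, mu, logvar), 4))   # 0.2107
```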

The network becomes stable after several epochs:

Epoch: 0000 cost= 7667393.486328125
Epoch: 0001 cost= 4704399.156250000
Epoch: 0002 cost= 4212229.031250000
Epoch: 0003 cost= 4010008.686523438
Epoch: 0004 cost= 3880332.463867188
Epoch: 0005 cost= 3792471.577148438
Epoch: 0006 cost= 3725320.927734375
Epoch: 0007 cost= 3677255.272460938
Epoch: 0008 cost= 3638155.862304688
Epoch: 0009 cost= 3600829.377929688
Epoch: 0010 cost= 3575842.495117188
Epoch: 0011 cost= 3545074.799804688
Epoch: 0012 cost= 3524245.027832031

Here is the result:

Looks pretty good, doesn't it? Now let's decrease the dimension of the latent space. Here is when **z = 10**:

Here is when **z = 2**:

The observation here is that even when z is as small as 2, the result is pretty decent. So the network has basically learned how to represent the digits: the latent space is divided into several regions, where each region represents a digit and the decoder produces the corresponding one.
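One common way to see these regions (a sketch, since the trained decoder itself is not shown here) is to sweep a regular grid over the 2-D latent space and decode each point; neighboring grid points should decode to the same digit, with boundaries between regions:

```python
import numpy as np

# Build a 2-D grid of latent points to feed a trained decoder; each
# region of the grid should decode to a different digit.
n = 5
grid = np.stack(np.meshgrid(np.linspace(-2, 2, n),
                            np.linspace(-2, 2, n)), axis=-1).reshape(-1, 2)
print(grid.shape)  # (25, 2)

# images = decoder(grid)  # decoder is the trained network, not shown here
```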

CIFAR-10 dataset

Now let's increase the difficulty by using the CIFAR-10 dataset. Images in CIFAR-10 are 32×32×3, with height and width of 32 pixels and 3 color channels. A single image thus has 3072 values, and each color channel ranges from 0 to 255. Not only does the size become bigger, the representation is also much more sophisticated than in MNIST. Let's try to train the basic fully connected network and see how this goes.
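The preprocessing implied above can be sketched as follows (the batch here is random stand-in data with CIFAR-10's shape and dtype, not the real dataset):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a CIFAR-10 batch: uint8 pixels in [0, 255], shape (N, 32, 32, 3).
batch = rng.integers(0, 256, (16, 32, 32, 3), dtype=np.uint8)

x = batch.astype(np.float32) / 255.0   # normalize channel values to [0, 1]
x = x.reshape(len(x), -1)              # flatten each image to a 3072-vector
print(x.shape)                         # (16, 3072)
```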

TODO

  • Add the results of the CIFAR-10 dataset trained on the basic network
  • Add the results of the greyscale CIFAR-10 dataset

Results of the CIFAR-10 dataset on the basic network:

Now we can see that the results are pretty blurry even after many epochs of training, and it seems that the network has already converged and can't improve anymore. In fact, the network is extremely unstable: training for more than 20 epochs can result in a **nan** loss, after which the network collapses to producing all-black or all-white images for whatever input and for any random latent vector. Here are those cases:

Why does this network work so well on the MNIST dataset but fail on CIFAR-10? Is it because the CIFAR-10 dataset is colored? To verify this, let's train the network on greyscale versions of the original CIFAR-10 images. Here's the result:

The same kind of **nan** error during training also occurred:

The results are still pretty poor, so it doesn't have anything to do with whether the images are colored or not.
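One common source of such **nan** losses (an assumption here, since the wiki's exact cost code isn't shown) is the cross-entropy term hitting log(0) when the decoder's sigmoid output saturates at exactly 0 or 1. A small clip keeps the loss finite:

```python
import numpy as np

x = np.array([0.0, 1.0])
x_hat = np.array([0.0, 1.0])          # saturated sigmoid outputs
with np.errstate(divide="ignore", invalid="ignore"):
    raw = -np.sum(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat))
print(np.isfinite(raw))               # False: log(0) turns the loss into nan

x_hat = np.clip(x_hat, 1e-8, 1 - 1e-8)   # one common guard against this
safe = -np.sum(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat))
print(np.isfinite(safe))              # True
```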

After thinking about it for a while, I think the major reason is that each color channel of each pixel in the CIFAR-10 dataset takes one of 256 distinct values (0-255). When normalized to the range [0, 1], it is effectively a continuous float. The MNIST dataset, however, only has pixel values 0 and 1, essentially binary. The fully connected network outputs floats in the range [0, 1]. So in the case of MNIST, the network's representation space is much richer than required, and it can work well even with a latent space of dimension 2. In the case of CIFAR-10, however, the network's outputs and the original pixel data are both at the level of continuous floats, so the network can only learn the contour of the image and produce blurry content.
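Following that reasoning, one common adjustment (not tried in this wiki, just noted as an option) is to swap the Bernoulli cross-entropy reconstruction term for a squared-error term when pixels are continuous in [0, 1]:

```python
import numpy as np

def mse_recon(x, x_hat):
    """Squared-error reconstruction term, a common alternative to the
    Bernoulli cross-entropy when pixel values are continuous in [0, 1]."""
    return np.sum((x - x_hat) ** 2)

x = np.array([0.25, 0.5, 0.75])         # continuous CIFAR-style pixel values
print(mse_recon(x, x))                  # 0.0 at a perfect reconstruction
print(round(mse_recon(x, x + 0.1), 4))  # ~0.03 for a small uniform error
```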

The code for this section is here

Advanced VAE network

In the basic network, both the Encoder and Decoder consist only of fully connected layers, so it's hard for them to discover any underlying visual structure. The natural way to solve this problem is to use convolutional layers to encode and deconvolutional (transposed convolution) layers to decode. For the encoder, I didn't use any pooling layers, because pooling can easily lose information from the original image that might be useful when reconstructing in the decoder.
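Without pooling, downsampling comes from strided convolutions instead. The layer sizes below are hypothetical (the wiki's exact architecture is only in the linked code), but the shape arithmetic shows how a 32×32 CIFAR-10 image shrinks through three stride-2 layers:

```python
def conv_out(size, kernel, stride, pad):
    """Output spatial size of a strided convolution with explicit padding."""
    return (size + 2 * pad - kernel) // stride + 1

# Hypothetical encoder: three stride-2 conv layers on a 32x32 CIFAR-10 image,
# downsampling by striding instead of discarding information via max-pooling.
s = 32
for _ in range(3):
    s = conv_out(s, kernel=4, stride=2, pad=1)
print(s)  # 4: a 4x4 feature map feeds the final fully connected latent layers
```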

The code for the next sections is also here

CIFAR-10 dataset AGAIN

Now let's try this on the CIFAR-10 dataset. Here are the results:

Interestingly, using the advanced network doesn't actually solve the problem, and the results are even worse than those of the basic network. I think the reason is that the CIFAR-10 dataset contains so many different kinds of images that the Encoder-Decoder network cannot capture the structure behind all of them.

CelebA dataset

Let's try the CelebA dataset to verify whether my reasoning is correct. This dataset consists only of human faces and therefore has one consistent underlying structure. Here is the result:

The results are pretty decent but do not improve when I train for more epochs. Notice that the reconstructed faces are all smiling and without glasses (even if the original image has glasses). I think this is because smiling faces without glasses form the majority of the training data, so the network hasn't learned how to encode and decode other facial expressions or glasses.

Reference

  1. Kvfran's blog post and github
  2. Tensorflow's tutorial on basic VAE
