Variational Autoencoder
Briefly speaking, a Variational Autoencoder is a neural network that encodes an image into a vector in a latent space of dimension z. The vector is treated as a random sample drawn from a z-dimensional normal distribution. From this encoded representation, the Decoder network can then reconstruct the original image. The reason the latent space is modeled as a z-dimensional normal distribution is that we can draw random samples from the distribution and feed them into the Decoder network. That way we can obtain brand new images that are not even in the dataset we trained on!
If you haven't learned anything about VAEs yet, here is a good place to find an explanation. The next sections will assume basic knowledge of VAEs.
The most basic VAE network takes the images, each flattened from an image matrix into a single vector of length dim. The input, of size [batch_size, dim], is fed through 3 fully connected layers. The last fully connected layers then produce a mean vector and a standard deviation vector in the latent space.
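As a concrete illustration, here is a minimal NumPy sketch of this forward pass: a toy linear "encoder" producing mean and log-variance vectors, plus the reparameterization step that turns them into a latent sample. All weights, names, and shapes here are made up for illustration; this is not the repository's actual code.

```python
import numpy as np

rng = np.random.default_rng(0)

def encode(x, w_mu, w_logvar):
    """Toy linear 'encoder': map a flattened image to mean and log-variance."""
    mu = x @ w_mu
    logvar = x @ w_logvar
    return mu, logvar

def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps, so gradients can flow through mu and sigma."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps

dim = 784        # flattened 28x28 Mnist image
z_dim = 100      # latent dimension used in the experiment below
batch_size = 4   # arbitrary toy batch size

x = rng.random((batch_size, dim))
w_mu = rng.standard_normal((dim, z_dim)) * 0.01
w_logvar = rng.standard_normal((dim, z_dim)) * 0.01

mu, logvar = encode(x, w_mu, w_logvar)
z = reparameterize(mu, logvar, rng)
print(z.shape)  # (4, 100)
```

In a real VAE the encoder would of course be several nonlinear layers rather than a single matrix multiply, but the reparameterization step is the same.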
The Mnist dataset has images coded in 0 and 1, namely black and white pixels. Let's try to run the basic network with z = 100. This is supposed to be overkill because the image dimension is 28*28 = 784, so we are merely compressing the data by a factor of about 8.
The cost becomes stable after several epochs:
```
Epoch: 0000  cost = 7667393.486328125
Epoch: 0001  cost = 4704399.156250000
Epoch: 0002  cost = 4212229.031250000
Epoch: 0003  cost = 4010008.686523438
Epoch: 0004  cost = 3880332.463867188
Epoch: 0005  cost = 3792471.577148438
Epoch: 0006  cost = 3725320.927734375
Epoch: 0007  cost = 3677255.272460938
Epoch: 0008  cost = 3638155.862304688
Epoch: 0009  cost = 3600829.377929688
Epoch: 0010  cost = 3575842.495117188
Epoch: 0011  cost = 3545074.799804688
Epoch: 0012  cost = 3524245.027832031
```
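For reference, the cost a VAE minimizes is typically the sum of a reconstruction term and a KL-divergence term; the magnitudes above suggest it is summed over the batch and pixels rather than averaged. Here is a NumPy sketch of that cost on toy data (the function name and shapes are assumptions for illustration, not the repository's code):

```python
import numpy as np

def vae_cost(x, x_recon, mu, logvar):
    """Summed (not averaged) VAE cost: binary cross-entropy reconstruction
    term plus the closed-form KL(N(mu, sigma^2) || N(0, 1))."""
    eps = 1e-8  # avoid log(0)
    bce = -np.sum(x * np.log(x_recon + eps) + (1 - x) * np.log(1 - x_recon + eps))
    kl = -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))
    return bce + kl

rng = np.random.default_rng(0)
x = (rng.random((2, 784)) > 0.5).astype(float)          # toy binary "Mnist" batch
x_recon = np.clip(rng.random((2, 784)), 1e-6, 1 - 1e-6)  # toy decoder output
mu = rng.standard_normal((2, 100)) * 0.1
logvar = rng.standard_normal((2, 100)) * 0.1

cost = vae_cost(x, x_recon, mu, logvar)
```

As a sanity check, a perfect reconstruction combined with a standard-normal posterior (mu = 0, logvar = 0) drives this cost to essentially zero.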
Here is the result:
The observation here is that even when z is as small as 2, the result is pretty decent. So the network has basically learned how to represent the digits: the latent space is divided into several regions, each region represents a digit, and the decoder produces the corresponding one.
Now let's increase the difficulty by using the Cifar dataset. Images in the Cifar dataset are 32*32*3, with height and width being 32 pixels and 3 color channels. A single image therefore has size 3072, and each color channel takes values in 0-255. Not only does the size become bigger, the representation is also much more sophisticated than in the Mnist dataset. Let's train the basic fully connected network and see how this goes.
TODO
- Add the results of the cifar dataset trained on basic network
- Add the results of the greyscale Cifar dataset.
Results of the Cifar dataset on the basic network:
Why does this network work so well on the Mnist dataset but not on the Cifar dataset? Is it because the Cifar dataset is colored? To verify this, let's train the network on greyscale versions of the original Cifar images. Here is the result:
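Converting Cifar images to greyscale can be done with the standard luma weights (ITU-R BT.601). A small NumPy sketch, where the function name and the channels-last batch layout are assumptions for illustration:

```python
import numpy as np

def to_greyscale(batch):
    """Convert images of shape (N, 32, 32, 3) with values in [0, 1] to
    greyscale using the ITU-R BT.601 luma weights (which sum to 1)."""
    weights = np.array([0.299, 0.587, 0.114])
    return batch @ weights  # contracts the channel axis -> (N, 32, 32)

rng = np.random.default_rng(0)
batch = rng.random((8, 32, 32, 3))  # toy stand-in for a Cifar batch
grey = to_greyscale(batch)
print(grey.shape)  # (8, 32, 32)
```

Because the weights sum to 1, the greyscale values stay in [0, 1] and can be fed to the same network as the Mnist experiments, modulo the input size.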
After thinking about it for a while, I think the major reason is that each color channel of each pixel in the Cifar dataset ranges from 0 to 255, i.e. 256 distinct values. After normalizing to the range [0, 1], a pixel is effectively a float. The Mnist dataset, however, only has 0 and 1, basically integers. The fully-connected network outputs floats in the range [0, 1]. So for the Mnist dataset, the network's representation space is much richer than required, which is why it works so well even with a latent space of dimension 2. For the Cifar dataset, however, the network output and the original pixel data are both at the precision of floats, so the network can only learn the contour of the image and produces blurry content.
The code for this section is here
In the basic network, both the Encoder and Decoder consist only of fully-connected layers, so it's hard for them to discover any underlying visual structure. The natural way to solve this problem is to use convolutional layers to encode and deconvolutional (transposed convolution) layers to decode. For the encoder, I didn't use any pooling layers, because pooling can easily lose information about the original image that might be useful when reconstructing in the decoder.
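With no pooling, the encoder can downsample with strided convolutions, and the decoder mirrors them with transposed convolutions. A quick sketch of the shape arithmetic (the kernel/stride/padding values below are hypothetical, not necessarily those used in the code):

```python
def conv_out(size, kernel, stride, pad):
    """Output spatial size of a convolution:
    floor((size + 2*pad - kernel) / stride) + 1."""
    return (size + 2 * pad - kernel) // stride + 1

def deconv_out(size, kernel, stride, pad):
    """Output spatial size of a transposed convolution (inverts conv_out):
    (size - 1) * stride - 2*pad + kernel."""
    return (size - 1) * stride - 2 * pad + kernel

# Hypothetical encoder: three stride-2 convolutions on a 32x32 Cifar image.
s = 32
for _ in range(3):
    s = conv_out(s, kernel=4, stride=2, pad=1)   # 32 -> 16 -> 8 -> 4
print(s)  # 4

# The mirrored decoder recovers the original resolution.
for _ in range(3):
    s = deconv_out(s, kernel=4, stride=2, pad=1)  # 4 -> 8 -> 16 -> 32
print(s)  # 32
```

Choosing kernel 4, stride 2, padding 1 makes each stage exactly halve (and the transposed stage exactly double) the spatial size, which keeps the encoder and decoder symmetric without any pooling.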
The code for the next sections is also here
Now let's try this on the Cifar-10 dataset. The results are here:
Interestingly, using the advanced network doesn't actually solve the problem, and the results are even worse than those of the basic network. I think the reason is that the Cifar-10 dataset contains so many different kinds of images that the Encoder-Decoder network cannot capture the structure behind all of them.
Let's try the CelebA dataset to verify whether my reasoning is correct. This dataset consists only of human faces and therefore has one consistent underlying structure. Here is the result: