
Variational Autoencoders trained on MNIST Dataset using PyTorch



ac-alpha/VAEs-using-Pytorch


Variational Auto-Encoders using PyTorch

A Variational Auto-Encoder implementation in PyTorch, trained on the MNIST dataset.

Introduction

Variational autoencoders (VAEs) are a class of generative models used to learn distributions P(X) defined over datapoints X in some potentially high-dimensional space 𝒳. We are given examples X distributed according to some unknown distribution P_gt(X), and our goal is to learn a model P that we can sample from, such that P is as similar as possible to P_gt.

First, we map the original image X to a latent variable z using a distribution Q(z|X). We then pass this z to the distribution P(X|z) to obtain an image as close as possible to the original image X. Before we can say that our model is representative of the dataset, we need to make sure that for every datapoint X in the dataset there is at least one setting of the latent variables that causes the model to generate something very similar to X.
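A minimal sketch of this encode, sample, decode pipeline, assuming flattened 28x28 inputs, a single 400-unit hidden layer and 20 latent variables. The latent size matches the "VAE Vanilla" description in the Code section; the hidden size, the class name `VAE` and its method names are assumptions for illustration, not the repository's actual code. Q(z|X) is modelled as a diagonal Gaussian whose mean and log-variance come from the encoder, and z is drawn with the reparameterization trick so gradients can flow through the sampling step.

```python
import torch
from torch import nn
import torch.nn.functional as F


class VAE(nn.Module):
    """Fully connected VAE for flattened MNIST images (hypothetical layer sizes)."""

    def __init__(self, input_dim: int = 784, hidden_dim: int = 400, latent_dim: int = 20):
        super().__init__()
        # Encoder Q(z|X): produces the mean and log-variance of a diagonal Gaussian.
        self.enc = nn.Linear(input_dim, hidden_dim)
        self.enc_mu = nn.Linear(hidden_dim, latent_dim)
        self.enc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder P(X|z): maps z back to per-pixel probabilities in [0, 1].
        self.dec = nn.Linear(latent_dim, hidden_dim)
        self.dec_out = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.enc(x))
        return self.enc_mu(h), self.enc_logvar(h)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I), differentiable w.r.t. mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z):
        h = F.relu(self.dec(z))
        return torch.sigmoid(self.dec_out(h))

    def forward(self, x):
        # Flatten (N, 1, 28, 28) images to (N, 784) before encoding.
        mu, logvar = self.encode(x.view(x.size(0), -1))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
```

`model(x)` returns the reconstruction together with `mu` and `logvar`, which are needed later for the KL term of the loss.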

We wish to optimize the parameters so that we can sample z from the prior P(z) and, with high probability, P(X|z) will resemble the X's in our dataset. However, for most z, P(X|z) will be nearly zero and hence contributes almost nothing to an estimate of P(X). The key idea behind the variational autoencoder is to sample values of z that are likely to have produced X, and to compute P(X) from just those.

The distribution Q(z|X) is usually taken to be Gaussian. Putting everything together, the loss function is the sum of two terms: the KL divergence between Q(z|X) and P(z), i.e. D(Q(z|X)||P(z)), and the cross-entropy loss between the original X and the reconstructed X. For more details, see the paper and tutorial.
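A minimal sketch of this loss, assuming Q(z|X) is a diagonal Gaussian parameterized by `mu` and `logvar`, the prior P(z) is the standard normal N(0, I), and the cross-entropy term is per-pixel binary cross-entropy (a common choice for MNIST pixels scaled to [0, 1]; treating it as binary is an assumption here). With these choices the KL term has a closed form, so no sampling is needed to evaluate it. The function name `vae_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F


def vae_loss(recon_x, x, mu, logvar):
    """Reconstruction cross-entropy plus D(Q(z|X) || P(z)), summed over the batch.

    Assumes recon_x holds per-pixel probabilities in [0, 1], x is scaled to [0, 1],
    Q(z|X) = N(mu, diag(exp(logvar))) and P(z) = N(0, I).
    """
    # Binary cross-entropy between the original and reconstructed pixels.
    bce = F.binary_cross_entropy(recon_x, x.view(x.size(0), -1), reduction="sum")
    # Closed-form KL between a diagonal Gaussian and the standard normal:
    # -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld
```

When `mu` is zero and `logvar` is zero, Q(z|X) equals the prior and the KL term vanishes, leaving only the reconstruction error.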

Pipeline

This image has been taken from this tutorial.

Dataset

For all the experiments, I used the MNIST dataset, loaded with PyTorch's own DataLoader.

Code

The code is well documented in the following files:

  1. VAE Vanilla - Simple VAE with 20 latent variables, built on a fully connected network.
  2. VAE_two_latent_variables - Fully connected network with only 2 latent variables.
  3. VAE_CNN - VAE using convolutional layers.
  4. VAE_without_KLD_Loss - VAE trained with only the cross-entropy loss and 2 latent variables.
  5. VAE_without_Cross_Entropy_Loss - VAE trained with only the KL divergence loss and 2 latent variables.
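The two ablation variants (VAE_without_KLD_Loss and VAE_without_Cross_Entropy_Loss) differ from the full model only in which loss term is kept. A minimal sketch of a single training step with hypothetical `use_bce` / `use_kld` switches, assuming a model whose forward pass returns `(recon_x, mu, logvar)` with `recon_x` in [0, 1]; the function and flag names are illustrative and not taken from the repository's notebooks.

```python
import torch
import torch.nn.functional as F


def train_step(model, optimizer, x, use_bce=True, use_kld=True):
    """One optimisation step; the flags select which loss terms are kept.

    use_bce=False mimics the "without cross-entropy" variant,
    use_kld=False mimics the "without KLD" variant (hypothetical switches).
    """
    if not (use_bce or use_kld):
        raise ValueError("at least one loss term must be enabled")
    model.train()
    optimizer.zero_grad()
    recon_x, mu, logvar = model(x)
    loss = torch.zeros((), device=x.device)
    if use_bce:
        # Reconstruction term: binary cross-entropy over all pixels.
        loss = loss + F.binary_cross_entropy(recon_x, x.view(x.size(0), -1), reduction="sum")
    if use_kld:
        # Closed-form KL divergence to the N(0, I) prior.
        loss = loss - 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss.backward()
    optimizer.step()
    return loss.item()
```

Disabling both terms raises an error, since there would be nothing to optimise.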

Experiments and Visualisations

Experiments 1-3 are available in this notebook, 4-5 in this notebook, 6-7 in this notebook, and 8 in this notebook.

  1. When we sample epsilon values at regular intervals and build z from them using the mu and sigma of a given image, the reconstructed images all resemble the original image, with only minute differences.

  2. When we change only one dimension of z and then generate samples, they look very similar to the original image.

  3. Using a fixed value of epsilon, if we build two latent values z1 and z2 from the mu and sigma of two different images, then the values in between them generate samples that are a mix of the two digits.

  4. In the 2-latent-variable model, when we sample z values at regular intervals and pass them through the decoder, we can see the transitions between the digits.

  5. In the 2-latent-variable model, the scatter plot shows the points of each digit lying very close to those of other digits, so we can move between digits without passing through much noise.

  6. If we repeat experiment 1 with the VAE_without_KLD_Loss model, the reconstructed images look similar and are in fact identical.

  7. In the VAE_without_KLD_Loss model, the scatter plot contains blobs of different digits separated by random noise.

  8. In the VAE_without_Cross_Entropy_Loss model, the scatter plot contains blobs of different digits randomly mixed together, but they tend to lie closer to each other.

  9. As we increase the number of latent variables, more z values produce noise when passed through the decoder network.
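The latent-space manipulations behind the sampling, single-dimension, interpolation, grid and scatter-plot experiments can be sketched with a few small helpers. This is a minimal sketch under stated assumptions: `mu` and `logvar` are one image's encoder outputs of shape `(latent_dim,)`, experiment 1 applies the same scalar epsilon to every latent dimension, experiment 3 uses straight-line interpolation, and `encode` is any callable returning `(mu, logvar)` for flattened images. All function names are hypothetical; the resulting z tensors would be passed to the decoder to produce images.

```python
import torch


def sample_around(mu, logvar, eps_values):
    """Experiment 1: z = mu + sigma * eps for each scalar eps in eps_values."""
    std = torch.exp(0.5 * logvar)
    return mu + std * eps_values[:, None]  # shape (len(eps_values), latent_dim)


def traverse_dimension(mu, dim, values):
    """Experiment 2: vary only z[dim], keeping every other dimension at mu."""
    z = mu.expand(len(values), -1).clone()
    z[:, dim] = values
    return z


def interpolate(z1, z2, steps):
    """Experiment 3: evenly spaced points on the line from z1 to z2."""
    t = torch.linspace(0, 1, steps).unsqueeze(1)
    return (1 - t) * z1 + t * z2


def latent_grid(low, high, n):
    """Experiment 4: regular n x n grid over a 2-D latent space."""
    axis = torch.linspace(low, high, n)
    zx, zy = torch.meshgrid(axis, axis, indexing="ij")
    return torch.stack([zx.flatten(), zy.flatten()], dim=1)


@torch.no_grad()
def latent_means(encode, loader):
    """Experiments 5, 7, 8: collect posterior means and labels for a scatter plot."""
    mus, labels = [], []
    for x, y in loader:
        mu, _ = encode(x.view(x.size(0), -1))
        mus.append(mu)
        labels.append(y)
    return torch.cat(mus), torch.cat(labels)
```

For a 2-latent-variable model, plotting the two columns returned by `latent_means` coloured by label gives the kind of scatter plot described in experiments 5, 7 and 8.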

Contributing

Suggestions for the repository are more than welcome. To open a pull request, click here. No specific variable naming convention is required, as long as the names are descriptive.

License

See License
