A complete PyTorch implementation of a Variational Autoencoder with latent space exploration capabilities.
The model clusters similar digits together in 2D space:
Smooth transitions between digit types:
Original (top) vs Reconstructed (bottom):
A Variational Autoencoder is a generative model that:
- Encodes input data into a latent probability distribution
- Samples from this distribution using the reparameterization trick
- Decodes samples back to reconstructions
- Learns a structured, continuous latent space for generation
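The four steps above can be sketched as a minimal PyTorch module. This is an illustrative sketch following the architecture and hyperparameter table in this README (784 → 400 → 2); layer and method names are my own and may differ from those in `vae_mnist.py`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """Minimal VAE sketch: 784 -> 400 -> 2-D latent -> 400 -> 784."""
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=2):
        super().__init__()
        # Encoder: image -> hidden -> (mu, log-variance)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: z -> hidden -> image
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # logvar stores log(sigma^2)
        eps = torch.randn_like(std)    # eps ~ N(0, 1)
        return mu + std * eps

    def decode(self, z):
        h = F.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))  # pixel values in [0, 1]

    def forward(self, x):
        mu, logvar = self.encode(x.view(x.size(0), -1))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
```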
```bash
pip install torch torchvision matplotlib numpy tqdm --break-system-packages
```

Simply run:

```bash
python vae_mnist.py
```

This will:
- Download MNIST dataset (if needed)
- Train the VAE for 20 epochs
- Generate visualization plots
- Save the trained model
Shows how the loss decreases during training.
Compares original MNIST digits with their reconstructions. Shows how well the model learned to encode and decode.
The most interesting visualization! Shows all test images projected into 2D latent space, colored by digit class.
- Similar digits cluster together
- You can see the topology of digit space
- Smooth transitions between digit types
A grid of generated images created by sampling different points in the latent space.
- Shows smooth interpolation between different digit types
- Demonstrates the continuous nature of the learned space
- Each point in the grid corresponds to a different latent coordinate
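One way to build such a grid is to decode evenly spaced latent coordinates and tile the results into a single canvas. A sketch, assuming a model exposing a `decode(z)` method as in the architecture description (the function name and grid limits are illustrative):

```python
import torch

def manifold_grid(model, n=20, lim=3.0, img_size=28):
    """Decode an n x n grid of 2-D latent points into one tiled image."""
    xs = torch.linspace(-lim, lim, n)
    canvas = torch.zeros(n * img_size, n * img_size)
    with torch.no_grad():
        for i in range(n):
            for j in range(n):
                # Each grid cell corresponds to one latent coordinate (x, y)
                z = torch.tensor([[xs[j].item(), xs[i].item()]])
                img = model.decode(z).view(img_size, img_size)
                canvas[i*img_size:(i+1)*img_size,
                       j*img_size:(j+1)*img_size] = img
    return canvas
```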
Input (784) → Encoder → [μ, log(σ²)] → Reparameterization → z (2D) → Decoder → Output (784)
Key Components:
- Encoder: Maps 784-dim images to 2D latent distribution parameters (mean & variance)
- Reparameterization Trick: Samples z = μ + σ * ε (where ε ~ N(0,1))
- Decoder: Reconstructs 784-dim images from 2D latent vectors
- Loss Function: Reconstruction loss (BCE) + KL divergence
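The loss function combines the two terms listed above. A sketch of the standard formulation (the function name is illustrative; `vae_mnist.py` may organize this differently):

```python
import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar):
    """Reconstruction (BCE) + KL divergence, both summed over the batch."""
    # How well the decoder reproduced each pixel
    bce = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # Closed-form KL( N(mu, sigma^2) || N(0, 1) )
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld
```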
The 2D latent space makes it easy to:
- See digit clusters: Similar digits group together
- Generate new digits: Sample any point (x, y) in latent space
- Interpolate: Smoothly transition between digits
- Understand structure: See how the model organizes information
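Generation and interpolation both reduce to decoding chosen latent points. A sketch, assuming a trained model with a `decode(z)` method as described above (function names are illustrative):

```python
import torch

def generate_at(model, x, y):
    """Decode one latent coordinate (x, y) into a 28x28 image."""
    z = torch.tensor([[x, y]], dtype=torch.float32)
    with torch.no_grad():
        return model.decode(z).view(28, 28)

def interpolate(model, z_start, z_end, steps=10):
    """Decode evenly spaced points on the line between two latents."""
    ts = torch.linspace(0, 1, steps).unsqueeze(1)
    zs = (1 - ts) * z_start + ts * z_end  # (steps, 2)
    with torch.no_grad():
        return model.decode(zs).view(steps, 28, 28)
```

Because the KL term keeps the latent space continuous, every intermediate point decodes to a plausible digit rather than noise.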
| Parameter | Value | Description |
|---|---|---|
| latent_dim | 2 | Dimension of latent space (2D for visualization) |
| hidden_dim | 400 | Size of hidden layers |
| batch_size | 128 | Training batch size |
| learning_rate | 1e-3 | Adam optimizer learning rate |
| epochs | 20 | Number of training epochs |
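The table above can be expressed as a plain config dict (key names are illustrative and may not match the ones used in `vae_mnist.py`):

```python
# Hyperparameters from the table above
config = {
    'latent_dim': 2,       # 2-D latent space for visualization
    'hidden_dim': 400,     # size of hidden layers
    'batch_size': 128,     # training batch size
    'learning_rate': 1e-3, # Adam optimizer learning rate
    'epochs': 20,          # number of training epochs
}
```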
The KL term regularizes the latent space to follow a standard normal distribution N(0,1). This:
- Prevents "holes" in the latent space
- Ensures smooth interpolation
- Allows generation by sampling from N(0,1)
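The closed-form KL term makes this regularization concrete: it is exactly zero when the encoder outputs a standard normal, and grows as μ or σ drift away from it. A small numeric sketch (function name is illustrative):

```python
import math

def kl_to_standard_normal(mu, logvar):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) ) for one dimension."""
    return -0.5 * (1 + logvar - mu**2 - math.exp(logvar))

# Zero penalty when the latent already matches N(0, 1):
kl_to_standard_normal(0.0, 0.0)  # 0.0
# Growing penalty as the mean drifts away from zero:
kl_to_standard_normal(2.0, 0.0)  # 2.0
```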
While 2D is limited for complex data, it's perfect for:
- Visualization and understanding
- Learning VAE concepts
- Seeing the structure emerge
For production, you'd typically use 10-100 dimensions.
Direct sampling from N(μ, σ²) isn't differentiable. Instead:
z = μ + σ * ε, where ε ~ N(0,1)
This makes the randomness independent of the parameters we're learning.
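A quick autograd check illustrates why this works: because ε carries all the randomness, gradients flow cleanly back through μ and log(σ²).

```python
import torch

mu = torch.tensor([0.5], requires_grad=True)
logvar = torch.tensor([0.0], requires_grad=True)

# Sample via the trick: randomness lives in eps, not in mu or sigma
eps = torch.randn(1)
z = mu + torch.exp(0.5 * logvar) * eps

# Gradients flow back through both learned parameters
z.sum().backward()
print(mu.grad)  # tensor([1.]) -- dz/dmu = 1, regardless of eps
```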
```python
# In main(), modify:
'latent_dim': 10,       # Higher dimensions for better reconstruction
'epochs': 50,           # More epochs for better convergence
'learning_rate': 5e-4,  # Lower for more stable training
```

Note: Latent space visualization only works with 2D.

```
vae_mnist.py          # Main implementation
README.md             # This file
vae_model.pth         # Saved model (after training)
training_curves.png   # Loss plots
reconstructions.png   # Original vs reconstructed
latent_space.png      # 2D latent space visualization
manifold.png          # Generated digit grid
data/                 # MNIST dataset (auto-downloaded)
```
- Explore the manifold: Find regions that generate different digits
- Interpolate: Draw a path between two digits in latent space
- Increase dimensions: Try latent_dim=10 for better reconstructions
- Modify architecture: Add more layers or change activation functions
- Try different datasets: Fashion-MNIST, CIFAR-10 (needs CNN encoder/decoder)
- Original VAE paper: Auto-Encoding Variational Bayes
- Tutorial: Understanding VAEs
- Advanced: β-VAE, VQ-VAE, VAE-GAN hybrids
- GPU acceleration: Automatically uses CUDA if available
- Memory: Reduce batch_size if you run out of memory
- Convergence: Loss should steadily decrease. If it plateaus early, try lower learning rate
- Latent space: If clusters overlap heavily, try increasing latent_dim
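The GPU fallback mentioned above follows the usual PyTorch pattern:

```python
import torch

# Use CUDA when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Model and each batch must then live on the same device, e.g.:
#   model.to(device); x = x.to(device)
```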
Enjoy exploring the latent space!



