A PyTorch-based Convolutional Neural Network (CNN) for MNIST handwritten digit classification, built with a custom architecture, data augmentation, and regularization techniques.
- About the Dataset
- Project Structure
- Model Architecture
- Key Features
- Execution Pipeline
- Setup & Installation
- Usage
- Contributing
The system uses the MNIST Dataset, which consists of:
- 60,000 training images (28×28 grayscale)
- 10,000 test images
- 10 classes (digits 0–9)
The dataset is automatically downloaded via torchvision.datasets. Images are normalized using the dataset's specific mean and standard deviation.
Digit-Recognition/
├── requirements.txt # Python dependencies
├── README.md # Project documentation
├── main.ipynb # Main notebook (Training, Evaluation, Testing)
├── CNN_best_model.pth # Saved model weights (generated during training)
└── data/ # Dataset storage (created automatically)
The model follows a VGG-style design pattern, stacking convolutional layers with batch normalization and GELU activations before pooling.
Network Flow:
+-------------------------+
| Input Image (1x28x28) |
+-------------------------+
|
v
+-------------------------------------------------------+
| [BLOCK 1: Feature Extraction] |
| |
| Conv2D(32, k=3) -> BatchNorm -> GELU |
| | |
| v |
| Conv2D(32, k=3) -> BatchNorm -> GELU |
| | |
| v |
| MaxPool2D(k=2) |
| | |
| v |
| Dropout2d(p=0.2) |
+-------------------------------------------------------+
|
v (Tensor Shape: 32x14x14)
|
+-------------------------------------------------------+
| [BLOCK 2: Deeper Patterns] |
| |
| Conv2D(64, k=3) -> BatchNorm -> GELU |
| | |
| v |
| Conv2D(64, k=3) -> BatchNorm -> GELU |
| | |
| v |
| MaxPool2D(k=2) |
| | |
| v |
| Dropout2d(p=0.25) |
+-------------------------------------------------------+
|
v (Tensor Shape: 64x7x7)
|
+-------------------------------------------------------+
| [BLOCK 3: Abstract Concepts] |
| (No pooling here to preserve spatial grid) |
| |
| Conv2D(128, k=3) -> BatchNorm -> GELU |
| | |
| v |
| Dropout2d(p=0.3) |
+-------------------------------------------------------+
|
v (Tensor Shape: 128x7x7)
|
+-------------------------------------------------------+
| [CLASSIFIER HEAD] |
| |
| Flatten (Input vector size: 128*7*7 = 6272) |
| | |
| v |
| Linear(6272->512) -> BN1d -> GELU -> Dropout(p=0.35) |
| | |
| v |
| Linear(in=512, out=10) |
+-------------------------------------------------------+
|
v
+-------------------------+
| Final Output (10 Logits)|
+-------------------------+
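The diagram above could be written in PyTorch roughly as follows. This is a sketch, not the notebook's code: the class name `DigitCNN` and the helper `conv_block` are hypothetical, and `padding=1` is inferred from the tensor shapes in the diagram (a 3×3 convolution only preserves the 28×28 grid when padded).

```python
import torch
from torch import nn


def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    """Conv2D(k=3) -> BatchNorm -> GELU; padding=1 (assumed) keeps height and width."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.GELU(),
    )


class DigitCNN(nn.Module):  # hypothetical class name
    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        # Block 1: 1x28x28 -> 32x14x14
        self.block1 = nn.Sequential(
            conv_block(1, 32), conv_block(32, 32),
            nn.MaxPool2d(2), nn.Dropout2d(0.2),
        )
        # Block 2: 32x14x14 -> 64x7x7
        self.block2 = nn.Sequential(
            conv_block(32, 64), conv_block(64, 64),
            nn.MaxPool2d(2), nn.Dropout2d(0.25),
        )
        # Block 3: no pooling, 64x7x7 -> 128x7x7
        self.block3 = nn.Sequential(conv_block(64, 128), nn.Dropout2d(0.3))
        # Classifier head: 6272 -> 512 -> 10 logits
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.35),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return self.classifier(x)


# Shape check on a dummy batch; eval() disables dropout and uses
# BatchNorm running statistics.
model = DigitCNN().eval()
with torch.no_grad():
    dummy = torch.randn(4, 1, 28, 28)
    feat1 = model.block1(dummy)
    feat2 = model.block2(feat1)
    feat3 = model.block3(feat2)
    logits = model.classifier(feat3)
```

Running the blocks one at a time on a dummy batch is a quick way to confirm that the intermediate shapes match the ones annotated in the diagram.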
- Hardware Acceleration: Supports both CUDA (Linux/Windows) and MPS (macOS) for GPU acceleration, falling back to CPU if unavailable.
- Data Augmentation: Applies random rotations (±10°), translations (±10%), and scaling (90-110%) during training to improve generalization.
- Regularization: Uses staggered Dropout rates (0.2 to 0.35) and Label Smoothing (0.1) to prevent overfitting.
- Evaluation Setup: Includes vectorized accuracy metrics, confusion matrix generation, and visual error analysis.
- Configurable Data Loading: Centralized settings for batch_size, num_workers, and pin_memory allow users to optimize throughput for their specific hardware. (Note: Default settings are pre-tuned for Apple Silicon M3 Pro efficiency).
The code execution process is automated with the following logic:
- Model Check: The script checks for existing weights (CNN_best_model.pth).
  - If found, training is skipped, and weights are loaded.
  - If not found, the training loop begins.
- Optimization:
  - Optimizer: Adam (lr=1e-3)
  - Scheduler: StepLR (multiplies the learning rate by 0.5 every 10 epochs)
  - Early Stopping: Stops training if validation loss does not improve for 5 consecutive epochs.
- Testing: The best performing model (lowest validation loss) is loaded for final evaluation on the test set.
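The control flow above can be sketched as follows. The `EarlyStopping` class and `weights_available` helper are hypothetical names, the `nn.Linear` model is a stand-in for the CNN, and the validation losses are simulated so the example runs without any data.

```python
import os

import torch
from torch import nn


class EarlyStopping:
    """Signals a stop after `patience` consecutive epochs without a lower validation loss."""

    def __init__(self, patience: int = 5) -> None:
        self.patience = patience
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def weights_available(path: str = "CNN_best_model.pth") -> bool:
    """Model check: training is skipped when a weights file already exists."""
    return os.path.isfile(path)


model = nn.Linear(4, 2)  # placeholder standing in for the CNN
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
stopper = EarlyStopping(patience=5)

# Simulated validation losses: three improvements, then a plateau.
fake_val_losses = [0.50, 0.40, 0.35] + [0.36] * 5
lrs, stopped_at = [], None
for epoch, val_loss in enumerate(fake_val_losses, start=1):
    optimizer.step()   # a real loop runs forward/backward passes per batch first
    scheduler.step()   # once per epoch; the rate halves after every 10th call
    lrs.append(scheduler.get_last_lr()[0])
    if stopper.step(val_loss):
        stopped_at = epoch
        break
```

In a real run, the model's weights would be saved each time `best_loss` improves, so the file on disk always holds the lowest-validation-loss checkpoint. If the file stores a `state_dict` (an assumption), loading it would look like `model.load_state_dict(torch.load(path, map_location=device))`.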
This project requires Python 3.10 or higher.
Option 1: pip
- Create and activate a virtual environment:
    python -m venv venv
    # Windows
    venv\Scripts\activate
    # macOS/Linux
    source venv/bin/activate
- Install dependencies:
    pip install -r requirements.txt
Option 2: uv
- Install uv:
    pip install uv
- Create and activate a virtual environment:
    uv venv --python=python3.10
    source .venv/bin/activate
- Install dependencies:
    uv pip install -r requirements.txt
- Download Pre-trained Weights (Optional): If you prefer to skip training and use the pre-trained model immediately:
  - Go to the Releases section of this repository.
  - Download the CNN_best_model.pth file.
  - Place it in the root directory of the project (Digit-Recognition/).
  - The script will automatically detect this file and load the weights instead of training.
- Launch Jupyter Lab:
    jupyter lab main.ipynb
- Run the Pipeline: Execute the notebook cells. The main execution block handles the logic for training versus loading:
    # Automatically handles Training vs Loading based on file existence
    run_pipeline(cnn_model, train_loader, valid_loader, test_loader, device)
- View Outputs: The notebook will display:
  - Loss and Accuracy plots.
  - A Confusion Matrix heatmap.
  - A grid of misclassified images.
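The confusion matrix and the misclassified-image grid listed above both start from the same predicted and true labels. The sketch below shows one vectorized way to compute them with plain PyTorch; the function names are hypothetical, and the notebook may use a different approach (for example, a library helper).

```python
import torch


def confusion_matrix(preds: torch.Tensor, targets: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    """Count (true, predicted) pairs; rows are true labels, columns are predictions."""
    flat = targets * num_classes + preds  # unique index per (true, predicted) pair
    counts = torch.bincount(flat, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def accuracy(preds: torch.Tensor, targets: torch.Tensor) -> float:
    """Fraction of predictions that match the targets."""
    return (preds == targets).float().mean().item()


# Toy labels standing in for real test-set predictions.
targets = torch.tensor([0, 1, 2, 2, 9, 9])
preds = torch.tensor([0, 1, 2, 7, 9, 4])

cm = confusion_matrix(preds, targets)
acc = accuracy(preds, targets)
# Indices of misclassified samples, usable for picking images to plot.
wrong_idx = torch.nonzero(preds != targets).flatten()
```

Off-diagonal entries of `cm` show which digits are confused with each other, and `wrong_idx` selects the samples to display in an error-analysis grid.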
I welcome contributions! Whether it's optimizing the CNN architecture, adding new augmentation techniques, or improving the visualization, here is how you can help.
- Fork the repository
    # Clone your fork
    git clone https://github.com/yourusername/Digit-Recognition.git
    cd Digit-Recognition
- Create a feature branch
    git checkout -b feature/new-architecture
- Make your changes
  - Notebooks: Ensure main.ipynb runs sequentially without errors (Restart Kernel and Run All Cells).
  - Code Style: Follow standard Python conventions (PEP 8).
  - Validation: If you change the model, please run the full pipeline and include the new test accuracy in your PR description.
- Commit your changes
    git commit -m "Add elastic distortion to data augmentation"
- Push to your fork
    git push origin feature/new-architecture
- Open a Pull Request
  - Provide a clear description of what you changed.
  - Crucial: If your changes affect model performance, attach a screenshot of the new Loss/Accuracy plots or Confusion Matrix.
  - Reference any related issues (e.g., Closes #42).
Reporting bugs:
- Use GitHub Issues with the "bug" label.
- Include the specific cell where the error occurred.
- Provide your environment details (OS, PyTorch version, GPU/CPU).
Suggesting enhancements:
- Use GitHub Issues with the "enhancement" label.
- Describe the proposed improvement (e.g., "Implement Quantization for mobile deployment").
- Explain the potential benefit to the project.
Star the repository if you like it! 🌟