kartikpaliwal/SNodeViT

SNodeViT: Stochastic Neural ODE Vision Transformer

A toy implementation of a Vision Transformer that formulates attention dynamics as stochastic differential equations (SDEs) with fixed-rank KSVD decomposition.

🏗️ Architecture Overview

Standard Vision Transformer (ViT)

The standard ViT architecture consists of:

  • Patch Embedding: Images are divided into fixed-size patches (e.g., 16×16) and linearly projected
  • Positional Embeddings: Learnable position encodings added to patch embeddings
  • Transformer Blocks: Multi-head self-attention + MLP layers
  • Classification Head: Final linear layer for classification
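The patch embedding step can be sketched in a few lines. This is a minimal NumPy illustration rather than the repository's code: the function name `patchify`, the random projection `w_proj`, and the zero-initialized `pos_embed` are placeholders introduced here, and the embedding width of 96 is borrowed from the Tiny variant.

```python
import numpy as np

def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split a (C, H, W) image into flattened, non-overlapping patches.

    Returns an array of shape (num_patches, C * patch_size * patch_size).
    """
    c, h, w = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    gh, gw = h // patch_size, w // patch_size
    # (C, gh, p, gw, p) -> (gh, gw, C, p, p) -> (gh * gw, C * p * p)
    patches = image.reshape(c, gh, patch_size, gw, patch_size)
    patches = patches.transpose(1, 3, 0, 2, 4)
    return patches.reshape(gh * gw, c * patch_size * patch_size)

rng = np.random.default_rng(0)
image = rng.standard_normal((3, 224, 224))
patches = patchify(image, patch_size=16)  # 14 x 14 grid of patches

embed_dim = 96  # embedding width of the Tiny variant
# Hypothetical stand-ins for the learned projection and position table
w_proj = rng.standard_normal((patches.shape[1], embed_dim)) * 0.02
pos_embed = np.zeros((patches.shape[0], embed_dim))
tokens = patches @ w_proj + pos_embed  # one embedding vector per patch
```

A 224×224 image with 16×16 patches yields 196 tokens, each built from 3 × 16 × 16 = 768 raw pixel values.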

Mathematical Formulation:

Attention(Q,K,V) = softmax(QK^T/√d_k)V

where Q, K, V are queries, keys, and values derived from input embeddings.
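For reference, the formula above translates directly into a single-head NumPy sketch. The helper names `softmax` and `attention` are chosen here for illustration and are not taken from the repository.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q: np.ndarray, k: np.ndarray, v: np.ndarray):
    """Scaled dot-product attention for one head; inputs are (N, d_k)."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)      # (N, N) similarity matrix
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
n_tokens, d_k = 5, 8
q = rng.standard_normal((n_tokens, d_k))
k = rng.standard_normal((n_tokens, d_k))
v = rng.standard_normal((n_tokens, d_k))
out, weights = attention(q, k, v)
```

The N × N `scores` matrix is the part that the fixed-rank formulation in the next section avoids computing exactly.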

Fixed-Rank KSVD Attention

Instead of computing full attention matrices, SNodeViT uses a low-rank approximation:

Mathematical Formulation:

A_t = ∑_{i=1}^{r} λ_i(t) ψ_i(t) ⊗ φ_i(t)

Where:

  • r is the fixed rank (it does not scale with the sequence length)
  • ψ_i(t) and φ_i(t) are time-dependent basis functions
  • λ_i(t) are time-dependent scaling factors
  • ⊗ denotes the outer product

Complexity: O(N² × r) where N is sequence length and r is fixed rank.
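To make the sum of outer products concrete, here is a small NumPy sketch at a single fixed time t. It assumes each ψ_i and φ_i is a length-N vector, which may not match how the repository represents its basis functions. Building the N × N matrix from r outer products costs O(N² × r), as stated above; the same product with the values can also be computed in factored form without ever forming that matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, rank, dim = 64, 8, 16

psi = rng.standard_normal((n_tokens, rank))  # column i plays the role of psi_i
phi = rng.standard_normal((n_tokens, rank))  # column i plays the role of phi_i
lam = rng.standard_normal(rank)              # lambda_i scaling factors
values = rng.standard_normal((n_tokens, dim))

# Materialized form: A = sum_i lambda_i * outer(psi_i, phi_i), an N x N matrix
a_full = sum(lam[i] * np.outer(psi[:, i], phi[:, i]) for i in range(rank))
out_full = a_full @ values

# Factored form: psi @ diag(lambda) @ (phi^T @ values), never builds N x N
out_factored = psi @ (lam[:, None] * (phi.T @ values))
```

Both paths give the same output, and the materialized matrix has rank at most r regardless of N.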

Basic SDE Integration

The attention dynamics are modeled as a stochastic differential equation:

Mathematical Formulation:

dX_t = f(X_t, θ(t)) dt + G(X_t, σ(t)) dW_t

Where:

  • X_t is the state at time t
  • f(X_t, θ(t)) is the drift term (deterministic evolution)
  • G(X_t, σ(t)) dW_t is the diffusion term (stochastic noise)
  • θ(t) and σ(t) are time-dependent parameters

Integration Method: Euler-Maruyama scheme

X_{t+dt} = X_t + f(X_t, θ(t)) dt + G(X_t, σ(t)) √dt N(0,1)
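The update rule above can be sketched as a short NumPy loop. The function names are illustrative, and the diffusion is treated as an element-wise scale on the noise, which is a simplifying assumption; in general G can be a matrix.

```python
import numpy as np

def euler_maruyama_step(x, drift, diffusion, t, dt, rng):
    """One Euler-Maruyama update with element-wise (diagonal) diffusion."""
    noise = rng.standard_normal(x.shape)  # fresh N(0, 1) sample per element
    return x + drift(x, t) * dt + diffusion(x, t) * np.sqrt(dt) * noise

def integrate(x0, drift, diffusion, t0=0.0, t1=1.0, num_steps=10, seed=0):
    """Integrate from t0 to t1 using num_steps equal steps."""
    rng = np.random.default_rng(seed)
    dt = (t1 - t0) / num_steps
    x, t = x0, t0
    for _ in range(num_steps):
        x = euler_maruyama_step(x, drift, diffusion, t, dt, rng)
        t += dt
    return x

# Toy dynamics: dX = -X dt + 0.1 dW
x0 = np.ones(4)
x1 = integrate(x0, drift=lambda x, t: -x, diffusion=lambda x, t: 0.1)

# With zero diffusion the scheme reduces to explicit Euler: x0 * (1 - dt) ** num_steps
x_det = integrate(x0, drift=lambda x, t: -x, diffusion=lambda x, t: 0.0)
```

Setting the diffusion to zero is a convenient check, since the result must then match the deterministic Euler solution exactly.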

📁 Project Structure

snodevit/
├── core/                    # Core model components
│   ├── attention.py        # Fixed-rank KSVD attention
│   ├── layers.py           # Continuous blocks and SDE integration
│   └── model.py            # Main SNodeViT architecture
├── training/                # Training utilities
│   ├── trainer.py          # Training loop with uncertainty
│   └── loss.py             # Loss functions
├── data/                    # Data loading and preprocessing
│   └── datasets.py         # Dataset managers
├── evaluation/              # Evaluation metrics
│   └── metrics.py          # Uncertainty and calibration metrics
├── configs/                 # Configuration files
│   ├── base.yaml           # Base configuration
│   ├── tiny.yaml           # Tiny model variant
│   ├── small.yaml          # Small model variant
│   ├── base_model.yaml     # Base model variant
│   └── large.yaml          # Large model variant
├── scripts/                 # Training and evaluation scripts
│   ├── train.py            # Main training script
│   └── evaluate.py         # Evaluation script
├── performance_charts/      # Generated performance charts
│   ├── training_performance.png
│   └── performance_summary.md
├── requirements.txt         # Python dependencies
├── setup.py                # Package setup
├── LICENSE                 # MIT License
└── README.md               # This file

🚀 Quick Start

Installation

git clone <your-repo>
cd snodevit
pip install -r requirements.txt

Basic Usage

import torch

from core import create_snodevit

# Create tiny model
model = create_snodevit(
    variant='tiny',
    num_classes=10,
    img_size=224
)

# Forward pass
x = torch.randn(1, 3, 224, 224)
output = model(x)  # [1, 10]

Training

# Train tiny model on CIFAR-10
python scripts/train.py \
    model.variant=tiny \
    data.name=cifar10 \
    training.batch_size=128 \
    training.learning_rate=2e-3

📊 Model Variants

Variant   Embed Dim   Depth   Heads   Basis Functions   Params
Tiny      96          6       3       8                 ~1.2M
Small     192         8       4       16                ~4.8M
Base      384         12      6       24                ~19M
Large     768         16      12      32                ~76M

🔧 Key Components

1. StochasticPrimalAttention

  • Fixed-rank KSVD: Uses predefined number of basis functions
  • Time-dependent weights: Neural networks that generate time-varying parameters
  • Stochastic noise: Adds controlled randomness for uncertainty quantification

2. ContinuousBlock

  • Neural SDE: Combines attention and MLP in continuous-time formulation
  • Euler-Maruyama integration: Simple first-order numerical integration scheme
  • Memory efficient: Gradient checkpointing support

3. TimeNet

  • Sinusoidal embedding: Converts time to high-dimensional representation
  • MLP processing: Generates time-dependent parameters
  • Stable initialization: Small weights for numerical stability
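A sinusoidal time embedding of the kind described for TimeNet might look like the sketch below. The sin/cos layout and the `max_period` value follow the common Transformer convention and are assumptions here; the repository's TimeNet may use a different layout.

```python
import numpy as np

def sinusoidal_time_embedding(t: float, dim: int, max_period: float = 10000.0) -> np.ndarray:
    """Map a scalar time t to a dim-dimensional vector of sines and cosines."""
    if dim % 2:
        raise ValueError("dim must be even")
    half = dim // 2
    # Geometrically spaced frequencies from 1 down towards 1 / max_period
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    angles = t * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])

emb0 = sinusoidal_time_embedding(0.0, dim=16)
emb1 = sinusoidal_time_embedding(0.5, dim=16)
```

An MLP would then map this vector to the time-dependent parameters, such as the λ_i(t) scaling factors.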

📈 Performance Results

Training Progress (Tiny Model on CIFAR-10)

  • Epoch 0: Loss: 2.05, Training Accuracy: 22.8%
  • Epoch 99: Loss: 0.46, Training Accuracy: 84.0%
  • Final: 79.4% validation accuracy

Training Performance Charts

Main Training Performance

[Figure: Training Performance]

Comprehensive Training Analysis

[Figure: Realistic Training Analysis]

Final Results (100 Epochs)

  • Training Accuracy: 84.0% (up 61.2 percentage points from epoch 0)
  • Validation Accuracy: 79.4%
  • Expected Calibration Error (ECE): 0.0279 (confidence and accuracy differ by about 2.8 percentage points on average)
  • Efficient: Gradient checkpointing enabled
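For readers unfamiliar with ECE, the sketch below shows one common way to compute it: group predictions into equal-width confidence bins and average the gap between accuracy and confidence, weighted by bin size. The choice of 10 bins is an assumption, and the metric in `evaluation/metrics.py` may be implemented differently.

```python
import numpy as np

def expected_calibration_error(confidences, correct, num_bins=10):
    """ECE with equal-width confidence bins.

    confidences: predicted probability of the chosen class, in (0, 1].
    correct: 1 if the prediction was right, 0 otherwise.
    """
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, num_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap  # weight by the fraction of samples in the bin
    return float(ece)

# Two predictions at 0.95 confidence (one right), two at 0.65 (both right)
conf = [0.95, 0.95, 0.65, 0.65]
hit = [1, 0, 1, 1]
ece = expected_calibration_error(conf, hit)
```

In this toy case the overconfident bin (gap 0.45) and the underconfident bin (gap 0.35) each hold half the samples, giving an ECE of 0.4.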

⚠️ Limitations & Trade-offs

1. Fixed Rank Issue

  • Problem: Rank r doesn't scale with sequence length N
  • Impact: Not truly scalable for larger models
  • Trade-off: Memory efficiency vs. expressiveness

2. Standard Patch Embeddings

  • Reality: Uses same patch embedding as standard ViT
  • Innovation: Only in attention mechanism and SDE integration
  • Trade-off: Simplicity vs. novelty

3. Basic SDE Integration

  • Method: Simple Euler scheme
  • Stability: Basic drift clipping and noise control
  • Trade-off: Simplicity vs. accuracy
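One plausible form of drift clipping is to rescale each token's drift vector so its norm stays below a threshold. The sketch below is an assumption about what such a safeguard could look like; the repository may instead clamp element-wise or use a different threshold, and `clip_drift` and `max_norm` are names introduced here.

```python
import numpy as np

def clip_drift(drift: np.ndarray, max_norm: float = 1.0, eps: float = 1e-12) -> np.ndarray:
    """Rescale each row of drift so its L2 norm is at most max_norm."""
    norms = np.linalg.norm(drift, axis=-1, keepdims=True)
    scale = np.minimum(1.0, max_norm / (norms + eps))  # leave small drifts untouched
    return drift * scale

drift = np.array([[3.0, 4.0],    # norm 5.0, will be scaled down
                  [0.3, 0.4]])   # norm 0.5, left unchanged
clipped = clip_drift(drift, max_norm=1.0)
```

Clipping by norm keeps the direction of the drift and only limits the step length, which bounds how far a single Euler-Maruyama update can move the state.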

🧪 Experiments

Uncertainty Quantification

# Get predictions with uncertainty
mean_pred, uncertainty = model.forward_with_uncertainty(x, num_samples=5)
print(f"Prediction: {mean_pred.argmax()}")
print(f"Uncertainty: {uncertainty.mean():.3f}")
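Because the model is stochastic, repeated forward passes on the same input give different outputs, and their spread can serve as an uncertainty estimate. The sketch below illustrates that idea with a toy noisy linear model. It is not the repository's `forward_with_uncertainty`, and `mc_predict` and `toy_forward` are hypothetical names.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mc_predict(stochastic_forward, x, num_samples=5):
    """Average class probabilities over several stochastic forward passes."""
    probs = np.stack([softmax(stochastic_forward(x)) for _ in range(num_samples)])
    return probs.mean(axis=0), probs.var(axis=0)  # mean prediction, per-class variance

rng = np.random.default_rng(0)
weights = rng.standard_normal((8, 10))

def toy_forward(x):
    # Stand-in for a stochastic model: linear logits plus Gaussian noise
    return x @ weights + 0.1 * rng.standard_normal((x.shape[0], 10))

x = rng.standard_normal((2, 8))
mean_pred, uncertainty = mc_predict(toy_forward, x, num_samples=5)
```

More samples give a smoother estimate at a proportional increase in inference cost.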

Attention Visualization

# Extract attention maps
attention_maps = model.get_attention_maps(x)
print(f"Number of attention layers: {len(attention_maps)}")

📚 Mathematical Background

1. Low-Rank Approximation

The attention matrix A is approximated as:

A ≈ UΣV^T

where U and V each have r orthonormal columns and Σ is an r × r diagonal matrix holding the r largest singular values.
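A truncated SVD is the textbook way to obtain such an approximation, and it makes the approximation error easy to inspect. The sketch below uses `numpy.linalg.svd` on a random matrix; it illustrates the general idea and is not how the model's attention is computed.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((32, 32))
rank = 4

# Thin SVD: a == u @ diag(s) @ vt, singular values sorted in descending order
u, s, vt = np.linalg.svd(a, full_matrices=False)

# Keep only the top-r singular triplets
a_r = (u[:, :rank] * s[:rank]) @ vt[:rank, :]

# Eckart-Young: the Frobenius error is the root sum of squares of the dropped singular values
err = np.linalg.norm(a - a_r)
expected = np.sqrt(np.sum(s[rank:] ** 2))
```

The error depends only on the discarded singular values, so a low rank works well when the spectrum decays quickly.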

2. Stochastic Differential Equations

SDEs model systems with both deterministic and random evolution:

dX_t = μ(X_t, t) dt + σ(X_t, t) dW_t

3. Euler-Maruyama Scheme

Numerical approximation for SDEs:

X_{n+1} = X_n + μ(X_n, t_n) Δt + σ(X_n, t_n) √Δt Z_n

where Z_n ~ N(0,1).
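A single step of the scheme has simple conditional moments, which gives a quick sanity check for an implementation. For a scalar state, writing μ_n = μ(X_n, t_n) and σ_n = σ(X_n, t_n) as shorthand (notation introduced here for brevity):

```latex
\begin{aligned}
\mathbb{E}\left[X_{n+1} \mid X_n\right] &= X_n + \mu_n \,\Delta t, \\
\operatorname{Var}\left[X_{n+1} \mid X_n\right] &= \sigma_n^2 \,\Delta t .
\end{aligned}
```

The noise therefore contributes a standard deviation of σ_n √Δt per step, which is where the √Δt factor in the update comes from.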

🎯 Use Cases

1. Research & Education

  • Understanding: How SDEs can model neural dynamics
  • Experimentation: Testing uncertainty quantification methods
  • Learning: Deep learning with continuous-time formulations

2. Small-Scale Applications

  • CIFAR-10/100: Image classification tasks
  • Prototyping: Quick experiments with novel architectures
  • Benchmarking: Comparing against standard ViT

3. Uncertainty-Aware Systems

  • Confidence estimation: Model prediction reliability
  • Risk assessment: Identifying uncertain predictions
  • Adaptive computation: Early stopping for easy examples

🔮 Future Improvements

1. Adaptive Rank Scaling

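import math

def compute_adaptive_rank(self, N, D, H):
    """Scale rank with input dimensions."""
    # Heuristic: 10% of sqrt(N * D), floored at 1 so the rank is never zero
    optimal_rank = max(1, int(math.sqrt(N * D) * 0.1))
    # Cap by sequence length N, embedding dim D, and head count H
    return min(optimal_rank, min(N, D, H))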
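As a worked example, here is a standalone version of the heuristic (without `self`, and with a floor of 1 added so the rank never reaches zero) applied to the Tiny variant. The sequence length of 197 assumes 16×16 patches on a 224×224 image plus one class token, which is the standard ViT setup and not confirmed for this repository.

```python
import math

def compute_adaptive_rank(n_tokens: int, dim: int, num_heads: int) -> int:
    """Standalone version of the adaptive-rank heuristic."""
    optimal_rank = max(1, int(math.sqrt(n_tokens * dim) * 0.1))
    return min(optimal_rank, n_tokens, dim, num_heads)

# Tiny variant: embedding dim 96, 3 heads; 196 patches + 1 class token (assumed)
uncapped = max(1, int(math.sqrt(197 * 96) * 0.1))
rank_tiny = compute_adaptive_rank(n_tokens=197, dim=96, num_heads=3)
```

Here the square-root term alone suggests a rank of 13, but the cap at the head count brings the result down to 3, so with this formula the number of heads is usually the binding limit.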

2. Advanced SDE Solvers

  • Stochastic Runge-Kutta methods: Higher-order integration schemes for SDEs
  • Adaptive time stepping: Dynamic step size selection
  • Stochastic solvers: More sophisticated noise handling
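As one example of a higher-order stochastic solver, the Milstein scheme adds a correction term to Euler-Maruyama that depends on the derivative of the diffusion coefficient. The scalar sketch below is illustrative only: the function names are introduced here, and the test problem is geometric Brownian motion, not the model's dynamics.

```python
import math
import random

def milstein_step(x, a, b, db_dx, dt, dw):
    """One Milstein update for the scalar SDE dX = a(X) dt + b(X) dW."""
    correction = 0.5 * b(x) * db_dx(x) * (dw * dw - dt)
    return x + a(x) * dt + b(x) * dw + correction

def simulate_gbm(x0=1.0, mu=0.05, sigma=0.2, t1=1.0, num_steps=100, seed=0):
    """Simulate geometric Brownian motion dX = mu X dt + sigma X dW."""
    rng = random.Random(seed)
    dt = t1 / num_steps
    x = x0
    for _ in range(num_steps):
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment ~ N(0, dt)
        x = milstein_step(x, lambda v: mu * v, lambda v: sigma * v,
                          lambda v: sigma, dt, dw)
    return x

x_end = simulate_gbm()

# With state-independent noise (db_dx = 0) the correction vanishes
# and the step is identical to Euler-Maruyama.
same = milstein_step(1.0, lambda v: -v, lambda v: 0.5, lambda v: 0.0, 0.1, 0.3)
```

The extra term only matters when the noise scale depends on the state, which is the case for the G(X_t, σ(t)) diffusion used in this model.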

3. Continuous Patch Embeddings

  • Time-varying kernels: Patches that evolve over time
  • Adaptive patch sizes: Dynamic receptive fields
  • Continuous convolutions: Smooth spatial transformations

📖 References

  1. Vision Transformer: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
  2. Neural ODEs: "Neural Ordinary Differential Equations"
  3. SDEs: "Stochastic Differential Equations: An Introduction with Applications"
  4. Low-Rank Attention: "Efficient Attention: Attention with Linear Complexities"

🤝 Contributing

This is a toy project for educational purposes. Feel free to:

  • Experiment with different architectures
  • Implement the suggested improvements
  • Share your findings and insights
  • Use as a starting point for research

📄 License

MIT License - feel free to use for research and education.


Note: This implementation focuses on clarity and educational value rather than production performance. The fixed-rank limitation and basic SDE integration make it suitable for understanding the concepts rather than achieving state-of-the-art results.

🚀 GitHub Preparation

Prerequisites

  • Python 3.8+
  • PyTorch 2.0+
  • CUDA (optional, for GPU training)

Local Development Setup

# Clone the repository
git clone <your-github-repo-url>
cd snodevit

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt

# Install in development mode
pip install -e .

Running Experiments

# Train tiny model on CIFAR-10
python scripts/train.py \
    model.variant=tiny \
    data.name=cifar10 \
    training.batch_size=128 \
    training.learning_rate=2e-3

# Evaluate trained model
python scripts/evaluate.py \
    --model-path checkpoints/best.pth \
    --config configs/tiny.yaml

Regenerating Performance Charts

# If you want to regenerate performance charts from new training logs
python generate_performance_charts.py

Contributing

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Commit your changes (git commit -m 'Add amazing feature')
  4. Push to the branch (git push origin feature/amazing-feature)
  5. Open a Pull Request

Issues and Discussions

  • Bug Reports: Use the Issues tab for bug reports
  • Feature Requests: Open an issue for new features
  • Questions: Use Discussions for questions and help
  • Improvements: Submit PRs for code improvements

Happy Learning! 🎓
