CosmoViT is a vision-transformer-style backbone with linear attention and optional persistent memory.
- Replaces softmax attention with a linear attention path using L1-normalized keys and a global $K^T V$ context.
- Applies a power-law nonlinearity to ReLU-activated Q/K to emphasize hub-like, high-activation token channels.
- Adds optional persistent memory with gated read/write to carry context across batches, with optional task gating.
- Uses bidirectional attention in each block by mixing forward and reversed token streams.
- Uses layer scaling in residual paths to stabilize deeper stacks.
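The linear-attention math described above can be sketched as follows. This is an illustrative NumPy version, not the repo's implementation: the exponent name `alpha` and the choice to L1-normalize keys over the token axis are assumptions.

```python
import numpy as np

def linear_attention(q, k, v, alpha=1.5, eps=1e-6):
    """Illustrative linear attention (parameter names are assumptions).

    q, k, v: (seq, dim). ReLU plus a power-law on Q/K sharpens
    high-activation "hub" channels; keys are L1-normalized over
    tokens so K^T V acts as a weighted global context. Cost is
    O(seq * dim^2) rather than the O(seq^2 * dim) of softmax attention.
    """
    q = np.maximum(q, 0.0) ** alpha                  # hub-like sparsity
    k = np.maximum(k, 0.0) ** alpha
    k = k / (k.sum(axis=0, keepdims=True) + eps)     # L1-normalize keys over tokens
    context = k.T @ v                                # (dim, dim) global K^T V context
    return q @ context                               # (seq, dim)
```

Because the `(dim, dim)` context is computed once and shared by every query, sequence length enters the cost only linearly.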
- Defines the package boundary so Cosmo modules share a single namespace.
- Keeps initialization side-effect free, with no model logic executed on import.
- Preserves a stable import surface for experiments and scripts.
- Reserved for future package-level metadata or exports.
- Projects tokens into multi-head Q/K/V with a widened sparse dimension.
- Applies ReLU and power-law scaling to create hub-like sparsity in attention activations.
- Replaces softmax with L1-normalized keys and computes a global $K^T V$ context.
- Supports persistent memory with gated read/write and optional task-conditioned gating.
- Blends attention output with the residual input via a learned gate.
- Each block mixes forward and reversed token streams for bidirectional context.
- Pre-norm attention and MLP branches use layer scaling for stable residual updates.
- The MLP branch adds nonlinear channel mixing after attention.
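The bidirectional, layer-scaled residual structure can be sketched like this. A minimal NumPy version under assumptions: `mix` stands in for the block's attention/MLP branch, and `gamma` is an assumed name for the layer-scale factor.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Per-token layer norm without learned affine parameters."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def bidirectional_block(tokens, mix, gamma=0.1):
    """Pre-norm residual block sketch: run the branch on the forward
    and reversed token streams, un-reverse the second, average the
    two views, and scale the residual update by gamma."""
    fwd = mix(layer_norm(tokens))
    bwd = mix(layer_norm(tokens[::-1]))[::-1]
    return tokens + gamma * 0.5 * (fwd + bwd)
```

Token dimensionality is preserved by construction, which is what the block tests check.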
- Patch embedding converts images to tokens and adds learnable positional embeddings.
- Returns token features and the patch grid size for downstream heads.
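The grid math behind patch embedding is the standard ViT patchification count; a small sketch, assuming image size is divisible by patch size:

```python
def patch_grid(image_size, patch_size):
    """Grid math used by patch embedding: a (grid x grid) layout of
    non-overlapping patches gives grid**2 tokens."""
    assert image_size % patch_size == 0, "image must tile evenly into patches"
    grid = image_size // patch_size
    return grid, grid * grid  # (grid size, token count)
```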
- Uses the Cosmo backbone as a feature extractor for FashionMNIST images.
- Converts grayscale inputs to 3-channel tokens for patch embedding.
- Pools token features with a mean to form a global representation.
- Applies an MLP head to map pooled features to class logits.
- Trains with cross-entropy and AdamW, saving the best checkpoint by test loss.
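The classifier's forward path (channel replication, mean pooling, MLP head) can be sketched as below. `backbone` and `head` are assumed callables standing in for the actual Cosmo modules.

```python
import numpy as np

def classify(gray_images, backbone, head):
    """Classifier forward path sketch:
    replicate grayscale to 3 channels, extract token features,
    mean-pool tokens into a global representation, map to logits."""
    x = np.repeat(gray_images, 3, axis=1)   # (B, 1, H, W) -> (B, 3, H, W)
    tokens = backbone(x)                    # (B, num_tokens, dim)
    pooled = tokens.mean(axis=1)            # (B, dim) global representation
    return head(pooled)                     # (B, num_classes) logits
```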
- Validates attention outputs are finite and shape-consistent across configs.
- Confirms bidirectional blocks preserve token dimensionality.
- Checks backbone grid math and token counts for different image/patch sizes.
- Ensures persistent memory does not update in eval mode.
- Estimates a scaling exponent vs sequence length to verify near-linear behavior.
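The scaling-exponent check amounts to a log-log fit of runtime against sequence length; a sketch of that estimate (function name is illustrative):

```python
import numpy as np

def scaling_exponent(seq_lens, runtimes):
    """Fit runtime ~ c * n**p on a log-log scale and return p.
    A slope near 1 indicates near-linear scaling in sequence length."""
    p, _ = np.polyfit(np.log(seq_lens), np.log(runtimes), 1)
    return p
```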
- Loads a trained checkpoint and runs inference on the FashionMNIST test set.
- Samples predictions and saves a denormalized image grid for quick inspection.
- Builds a confusion matrix and computes micro/macro precision, recall, and F1.
- Saves plots and a JSON report for reproducible evaluation.
- Returns metrics to make comparisons across checkpoints easy.
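The metrics above follow directly from the confusion matrix; a minimal sketch, assuming rows are true classes and columns are predictions:

```python
import numpy as np

def prf_from_confusion(cm):
    """Micro/macro precision, recall, and F1 from a confusion matrix
    where cm[i, j] counts true class i predicted as class j."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp          # predicted-as-class but wrong
    fn = cm.sum(axis=1) - tp          # true-class but missed
    prec = tp / np.maximum(tp + fp, 1)
    rec = tp / np.maximum(tp + fn, 1)
    f1 = 2 * prec * rec / np.maximum(prec + rec, 1e-12)
    # For single-label classification, micro P = R = F1 = accuracy.
    micro = tp.sum() / cm.sum()
    return {"macro_precision": prec.mean(), "macro_recall": rec.mean(),
            "macro_f1": f1.mean(), "micro_f1": micro}
```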
- datasets/: FashionMNIST downloads and cached data.
- models/: Saved checkpoints and evaluation outputs.
- Run model tests: `python test.py`
- Train a classifier: `python classifier.py`
- Evaluate a saved classifier: `python testClassifier.py`