CosmoViT

CosmoViT is a vision-transformer-style backbone with linear attention and optional persistent memory.

How CosmoViT Differs From Classic ViT

  • Replaces softmax attention with a linear attention path using L1-normalized keys and global $K^T V$ context.
  • Applies a power-law nonlinearity on ReLU Q/K to emphasize hub-like, high-activation token channels.
  • Adds optional persistent memory with gated read/write to carry context across batches, with optional task gating.
  • Uses bidirectional attention in each block by mixing forward and reversed token streams.
  • Uses layer scaling in residual paths to stabilize deeper stacks.
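The first two bullets can be sketched numerically. This is a minimal NumPy illustration of linear attention with L1-normalized keys and a global $K^T V$ context, not the repository's PyTorch code; the normalization axis and the power-law exponent `alpha` are assumptions:

```python
import numpy as np

def hub_features(x, alpha=1.5):
    # ReLU followed by a power law; alpha > 1 sharpens large activations,
    # producing the hub-like sparsity described above (alpha is illustrative).
    return np.maximum(x, 0.0) ** alpha

def linear_attention(q, k, v, eps=1e-6):
    # q, k, v: (n_tokens, d). Softmax is replaced by L1-normalizing the
    # (nonnegative) keys, then computing a single (d, d) context K^T V.
    q, k = hub_features(q), hub_features(k)
    k = k / (k.sum(axis=0, keepdims=True) + eps)  # L1-normalize keys over tokens
    context = k.T @ v                              # global (d, d) context K^T V
    return q @ context                             # every query reads the shared context
```

Because the `(d, d)` context is computed once and reused by every query, the cost is `O(n * d^2)` rather than the `O(n^2 * d)` of softmax attention, which is what makes the near-linear scaling claim in `test.py` checkable.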

File Theory Breakdown

__init__.py

  • Defines the package boundary so Cosmo modules share a single namespace.
  • Keeps initialization side-effect free, with no model logic executed on import.
  • Preserves a stable import surface for experiments and scripts.
  • Reserved for future package-level metadata or exports.

attention.py

  • Projects tokens into multi-head Q/K/V with a widened sparse dimension.
  • Applies ReLU and power-law scaling to create hub-like sparsity in attention activations.
  • Replaces softmax with L1-normalized keys and computes global context $K^T V$.
  • Supports persistent memory with gated read/write and optional task-conditioned gating.
  • Blends attention output with the residual input via a learned gate.
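The persistent-memory bullet can be illustrated with a toy gated read/write state. This is a sketch of the assumed mechanics only; the gate parameterization, shapes, and update rule are not taken from the repository:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

class PersistentMemory:
    """Toy gated read/write memory carried across calls (assumed design)."""

    def __init__(self, d, seed=0):
        rng = np.random.default_rng(seed)
        self.state = np.zeros(d)                      # persists across batches
        self.w_read = rng.normal(scale=0.1, size=d)   # read-gate weights
        self.w_write = rng.normal(scale=0.1, size=d)  # write-gate weights

    def __call__(self, x, training=True):
        # x: (n_tokens, d). Gated read: mix stored state into each token.
        read_gate = sigmoid(x @ self.w_read)
        out = x + read_gate[:, None] * self.state
        if training:  # as in test.py's check, the memory is frozen in eval mode
            summary = x.mean(axis=0)
            write_gate = sigmoid(summary @ self.w_write)
            self.state = (1.0 - write_gate) * self.state + write_gate * summary
        return out
```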

backbone.py

  • Each block mixes forward and reversed token streams for bidirectional context.
  • Pre-norm attention and MLP branches use layer scaling for stable residual updates.
  • The MLP branch adds nonlinear channel mixing after attention.
  • Patch embedding converts images to tokens and adds learnable positional embeddings.
  • Returns token features and the patch grid size for downstream heads.
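The block structure above can be sketched as a pre-norm residual update over forward and reversed token streams. The layer-scale constant and the way the two streams are combined are assumptions for illustration:

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    sd = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sd + eps)

def bidirectional_block(x, branch, gamma=0.1):
    # x: (n_tokens, d); branch is the attention or MLP sub-module.
    h = layer_norm(x)                           # pre-norm
    mixed = branch(h) + branch(h[::-1])[::-1]   # forward + reversed streams
    return x + gamma * mixed                    # layer-scaled residual update
```

A small `gamma` keeps each block's contribution to the residual stream bounded, which is the usual motivation for layer scaling in deeper stacks.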

classifier.py

  • Uses the Cosmo backbone as a feature extractor for FashionMNIST images.
  • Converts grayscale inputs to 3-channel tokens for patch embedding.
  • Pools token features with a mean to form a global representation.
  • Applies an MLP head to map pooled features to class logits.
  • Trains with cross-entropy and AdamW, saving the best checkpoint by test loss.
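The head described above reduces to channel repetition, mean pooling, and a small MLP. A minimal NumPy sketch (hidden width and weight shapes are illustrative, not the repository's values):

```python
import numpy as np

def to_three_channels(img):
    # FashionMNIST is grayscale (1, H, W); repeat to (3, H, W) for patch embedding.
    return np.repeat(img, 3, axis=0)

def classify(tokens, w_hidden, w_out):
    # tokens: (n_tokens, d) backbone features.
    pooled = tokens.mean(axis=0)                # mean-pool into a global vector
    hidden = np.maximum(pooled @ w_hidden, 0)   # ReLU hidden layer of the MLP head
    return hidden @ w_out                       # class logits (10 FashionMNIST classes)
```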

test.py

  • Validates attention outputs are finite and shape-consistent across configs.
  • Confirms bidirectional blocks preserve token dimensionality.
  • Checks backbone grid math and token counts for different image/patch sizes.
  • Ensures persistent memory does not update in eval mode.
  • Estimates a scaling exponent vs sequence length to verify near-linear behavior.
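The scaling-exponent check in the last bullet amounts to a log-log fit of runtime against sequence length. A sketch of that estimate (the fitting method is an assumption; the test file's exact procedure may differ):

```python
import numpy as np

def scaling_exponent(seq_lens, times):
    # Fit time ~ n^p by linear regression in log-log space;
    # p near 1 indicates near-linear attention cost.
    slope, _ = np.polyfit(np.log(seq_lens), np.log(times), 1)
    return slope
```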

testClassifier.py

  • Loads a trained checkpoint and runs inference on the FashionMNIST test set.
  • Samples predictions and saves a denormalized image grid for quick inspection.
  • Builds a confusion matrix and computes micro/macro precision, recall, and F1.
  • Saves plots and a JSON report for reproducible evaluation.
  • Returns metrics to make comparisons across checkpoints easy.
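The confusion-matrix and micro/macro metrics above follow standard definitions; a self-contained sketch (not the repository's evaluation code):

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes):
    # cm[t, p] counts samples of true class t predicted as class p.
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

def micro_f1(cm):
    # For single-label classification, micro precision = recall = F1 = accuracy.
    return np.diag(cm).sum() / cm.sum()

def macro_f1(cm):
    # Per-class precision/recall from column/row sums, then unweighted mean F1.
    tp = np.diag(cm).astype(float)
    prec = tp / np.maximum(cm.sum(axis=0), 1)
    rec = tp / np.maximum(cm.sum(axis=1), 1)
    f1 = np.where(prec + rec > 0, 2 * prec * rec / np.maximum(prec + rec, 1e-12), 0.0)
    return f1.mean()
```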

Folders

  • datasets/: FashionMNIST downloads and cached data.
  • models/: Saved checkpoints and evaluation outputs.

How To Run

  • Run model tests:
    • python test.py
  • Train a classifier:
    • python classifier.py
  • Evaluate a saved classifier:
    • python testClassifier.py
