This repository supports mechanistic interpretability experiments that use causal abstraction to reverse engineer the algorithm a neural network implements.
The codebase follows a causal abstraction approach: we hypothesize high-level causal models of how LLMs might solve tasks, and then locate where and how the variables of those models are represented inside the network.
A causal model (causal_model.py) consists of four ingredients (sketched in code after this list):
- Variables: Concepts that might be represented in the neural network
- Values: Possible assignments to each variable
- Parent-Child Relationships: Directed relationships showing causal dependencies
- Mechanisms: Functions that compute a variable's value given its parents' values
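As a concrete illustration, here is a minimal plain-Python sketch of these four ingredients for a toy two-digit addition task. It deliberately avoids the `CausalModel` API in causal_model.py, whose names and signatures may differ; treat everything below as illustrative.

```python
# Toy causal model for adding two digits, written as the four ingredients above.
variables = ["X", "Y", "SUM"]                                    # concepts the network might represent
values = {"X": list(range(10)), "Y": list(range(10)),
          "SUM": list(range(19))}                                # possible assignments to each variable
parents = {"X": [], "Y": [], "SUM": ["X", "Y"]}                  # directed causal dependencies
mechanisms = {
    "X": lambda: 3,                      # root variables take a default (or sampled) value
    "Y": lambda: 4,
    "SUM": lambda x, y: x + y,           # SUM is computed from its parents X and Y
}

# Run the model: compute each variable from its parents in topological order.
setting = {}
for v in variables:
    setting[v] = mechanisms[v](*[setting[p] for p in parents[v]])
print(setting)  # {'X': 3, 'Y': 4, 'SUM': 7}
```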
Mechanistic interpretability aims to reverse-engineer the algorithm a neural network implements to achieve a particular capability. Causal abstraction is a theoretical framework that makes these notions precise: an algorithm is a causal model, a neural network is a causal model, and implementation is the relation of causal abstraction between the two. The algorithm is a high-level causal model and the neural network is a low-level causal model; when the high-level mechanisms are accurate simplifications of the low-level mechanisms, the algorithm is a causal abstraction of the network.
What are the basic building blocks we should look at when trying to understand how AI systems work internally? This question is still being debated among researchers.
The causal abstraction framework remains agnostic on this question by allowing building blocks of any shape and size, which we call features. The features of a hidden vector in a neural network are accessed via an invertible featurizer, which might be an orthogonal matrix, the identity function, or an autoencoder. The neural network components are implemented in the neural/ directory with modular access to different model units.
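As a rough sketch (not the featurizers.py API; the function names here are hypothetical), an orthogonal-matrix featurizer and its inverse look like this:

```python
import torch

hidden_size = 16
Q = torch.linalg.qr(torch.randn(hidden_size, hidden_size)).Q   # random orthogonal basis

def featurize(h: torch.Tensor) -> torch.Tensor:
    return h @ Q             # rotate the hidden vector into feature coordinates

def inv_featurize(f: torch.Tensor) -> torch.Tensor:
    return f @ Q.T           # Q is orthogonal, so its transpose is its inverse

h = torch.randn(hidden_size)
assert torch.allclose(inv_featurize(featurize(h)), h, atol=1e-5)   # lossless round trip
```

Setting Q to the identity recovers the raw neuron basis, and replacing the forward/inverse pair with an encoder/decoder gives the autoencoder case.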
We use interchange interventions to test whether a variable in a high-level causal model aligns with specific features in the LLM. An interchange intervention replaces the values a model component computes on one input with the values it computes on another input, allowing us to isolate and test specific causal pathways.
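The pattern can be sketched on a toy PyTorch model: cache an activation from a run on the counterfactual (source) input, then patch it into a run on the base input. The real experiments route this through pyvene and the model-unit classes in neural/; the hook-based version below is only an illustration.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
base, source = torch.randn(1, 4), torch.randn(1, 4)
layer = model[0]                                     # the component we intervene on
cache = {}

def save_hidden(module, inputs, output):             # source run: record the activation
    cache["h"] = output.detach()

def patch_hidden(module, inputs, output):            # base run: overwrite it with the cached value
    return cache["h"]

handle = layer.register_forward_hook(save_hidden)
model(source)
handle.remove()

handle = layer.register_forward_hook(patch_hidden)
patched_output = model(base)                         # base input, but with the source activation swapped in
handle.remove()

# If the patched activation carries the hypothesized causal variable,
# patched_output should reflect the source input's value for that variable.
```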
The codebase implements five baseline approaches for feature construction and selection:
- Full Vector: Uses the entire hidden vector without any transformations.
- DAS (Distributed Alignment Search): Learns orthogonal directions with supervision from the causal model.
- DBM (Desiderata-Based Masking): Learns binary masks over features using the causal model for supervision. Can be applied to select neurons (standard dimensions of hidden vectors), PCA components, or SAE features (see the sketch after this list).
- PCA (Principal Component Analysis): Uses unsupervised orthogonal directions derived from principal components. DBM can be used to align principal components with a high-level causal variable.
- SAE (Sparse Autoencoder): Leverages pre-trained sparse autoencoders like GemmaScope and LlamaScope. DBM can be used to align SAE features with a high-level causal variable.
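As a hedged illustration of the DBM idea (not this codebase's training loop), a relaxed binary mask over feature dimensions can be trained so that interchanging only the masked features reproduces the counterfactual behavior, with a sparsity penalty keeping the selection small. The data, readout, and loss below are toy stand-ins; the label depends on the source value of the first two features only, so a good mask selects exactly those dimensions.

```python
import torch

n_features = 8
mask_logits = torch.zeros(n_features, requires_grad=True)
optimizer = torch.optim.Adam([mask_logits], lr=5e-2)

def interchange(base_feats, source_feats, mask):
    # keep unmasked features from the base run, take masked ones from the source run
    return (1 - mask) * base_feats + mask * source_feats

for step in range(200):
    base_feats = torch.randn(64, n_features)       # stand-ins for featurized activations
    source_feats = torch.randn(64, n_features)
    # toy counterfactual label: the first two features should come from the source run
    target = source_feats[:, :2].sum(-1) + base_feats[:, 2:].sum(-1)

    mask = torch.sigmoid(mask_logits)              # relaxed binary mask in [0, 1]
    mixed = interchange(base_feats, source_feats, mask)
    pred = mixed.sum(-1)                           # toy readout of the patched representation

    loss = torch.nn.functional.mse_loss(pred, target) + 0.05 * mask.mean()  # task loss + sparsity
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(torch.sigmoid(mask_logits).round())          # ideally ~[1, 1, 0, 0, 0, 0, 0, 0]
```

The same masking recipe applies whether the feature dimensions are raw neurons, PCA components, or SAE features.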
- causal_model.py: Implementation of causal models with variables, values, parent-child relationships, and mechanisms for counterfactual generation and intervention mechanics
- counterfactual_dataset.py: Dataset handling for counterfactual data generation and management
- pipeline.py: Abstract base pipeline and LM pipeline classes for consistent interaction with different language models
- model_units.py: Base classes for accessing model components and features in transformer architectures
- LM_units.py: Language model specific components for residual stream and attention head access
- featurizers.py: Invertible feature space definitions with forward/inverse featurizer modules and intervention utilities
- pyvene_core.py: Core utilities for creating, managing, and running intervention experiments using the pyvene library
- intervention_experiment.py: General intervention experiment framework
- filter_experiment.py: Filtering and selection experiments
- benchmark_experiment.py: Benchmark experiment utilities
- config.py: Configuration management for experiments
- experiment_utils.py: Utility functions for experiments
- visualizations.py: Visualization tools for experiment results
- LM_experiments/: Language model specific experiments
  - attention_head_experiment.py: Experiments targeting attention head components
  - residual_stream_experiment.py: Experiments on residual stream representations
  - LM_utils.py: Utility functions for language model experiments
Task implementations organized by directory. Each task contains:
- causal_models.py: Task-specific causal model definitions
- counterfactuals.py: Counterfactual generation logic
- token_positions.py: Token position specifications for interventions
Available tasks:
- MCQA/: Multiple Choice Question Answering task implementation with positional causal model
Outdated tasks:
- IOI/: Indirect Object Identification task (example from literature)
- entity_binding/: Entity binding task
- general_addition/: General addition task
Comprehensive test suite covering all core components, with specialized tests for pyvene integration in test_pyvene_core/.
```bash
git clone https://github.com/goodfire-ai/causalab.git
cd causalab
uv sync
```

To install for development, also run:

```bash
uv run pre-commit install  # Set up git hooks
```

Key dependencies:

- PyTorch: Deep learning framework for model operations
- pyvene: Library for causal interventions on neural networks
- transformers: Hugging Face library for language model access
- datasets: Hugging Face datasets library for data management
- scikit-learn: For dimensionality reduction techniques like PCA
- pytest: Testing framework
The best way to understand the codebase is through the onboarding tutorial notebooks in demos/onboarding_tutorial/:
- 01_define_MCQA_task.ipynb: Learn how to define causal models and counterfactual datasets
- 02_trace_residual_stream.ipynb: Trace information flow through language model layers
- 03_localize_with_patching.ipynb: Localize causal variables using activation patching
- 04_train_DAS_and_DBM.ipynb: Train precise interventions with supervised methods
To run the baseline test set used by the pre-merge check:

```bash
uv run pytest -m "not slow and not gpu"
```

For full coverage, you may simply run:

```bash
uv run pytest
```