This repository contains code for Predicting Functional Brain Connectivity with Context-Aware Deep Neural Networks, published at the Thirty-ninth Annual Conference on Neural Information Processing Systems (NeurIPS). The code implements experiments and modeling approaches for the fundamental problem of predicting region-to-region human brain functional connectivity from gene expression and spatial information.
Core repository functionalities are detailed below:
/sim
- sim.py → simulation class defining features, targets, cross-validation, and run tracking
- null.py → null spin brain map generation and evaluation
/data
- data_load.py → load datasets and auxiliary data
- data_utils.py → data utilities, core RegionPairDataset class used for training, and target augmentation helpers
- cv_split.py → train-test split and cross-validation classes
- BHA2/, HCP/, UKBB/ → subset of population average connectomes and parcellated transcriptomes
- enigma/ → subsetted gene lists with known biological functions, null spin test indices, and error metrics
/models
- rules_based.py, bilinear.py, pls.py, dynamic_mlp.py → baseline methods
- smt.py → FlashAttention-based MHSA transformer architectures, including the Spatiomolecular Transformer (SMT)
- smt_advanced.py → SMT variants incorporating auxiliary information and pretrained embeddings
- smt_cross.py → compressed mixed-attention architectures (in development)
- configs/ → hyperparameter sweep configs for each model
- saved_models/ → pretrained SMT models for each dataset
- train_val.py → global training/validation loop

/metrics
- eval.py → evaluation class with 32+ prediction metrics
/notebooks
- NeurIPS/ → Jupyter notebooks for analysis and figure creation
The main experimental entry point of GeneEx2Conn is the sim class, which lets users run detailed experiments tailored to the underlying research question. Below is a basic example of a sim run, which can be triggered from a notebook or the command line.
```python
single_sim_run(
    dataset='HCP',
    parcellation='S456',
    omit_subcortical=False,
    hemisphere='both',
    feature_type=[{'transcriptome': None}],  # input features
    gene_list='0.2',                         # genes to retain for transcriptome
    connectome_target='FC',                  # connectivity target
    cv_type='spatial',                       # train-test split
    random_seed=42,
    model_type='shared_transformer',         # model selection
)
```
See sim/sim_run.py for further details. Note: All Jupyter notebooks must be run from the root directory.
Careful train-test splits and null testing are critical to account for spatial autocorrelation in population average brain maps. /data/cv_split.py implements four train-test split styles: random; spatial (based on Euclidean coordinates); community (based on Louvain communities); and schaefer (based on Schaefer parcellation functional subnetworks). Random and spatial splits are visualized below for random seed = 42.
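To make the motivation concrete, here is a minimal sketch of how a spatial hold-out can be built from region coordinates: pick an anchor region and assign its nearest neighbors (by Euclidean distance) to the test fold, so train and test regions are spatially separated. This is an illustrative simplification, not the actual algorithm in /data/cv_split.py; the function name and parameters are assumptions.

```python
import numpy as np

def spatial_split(coords, test_frac=0.2, seed=42):
    """Hold out a spatially contiguous set of regions (hypothetical sketch).

    coords: (n_regions, 3) array of region centroid coordinates.
    Returns (train_idx, test_idx) index arrays.
    """
    rng = np.random.default_rng(seed)
    n = len(coords)
    anchor = rng.integers(n)
    # Euclidean distance from every region to the anchor region
    dists = np.linalg.norm(coords - coords[anchor], axis=1)
    order = np.argsort(dists)
    n_test = int(round(test_frac * n))
    # nearest regions to the anchor form the test fold; the rest train
    test_idx = order[:n_test]
    train_idx = order[n_test:]
    return train_idx, test_idx

# e.g., 100 regions with random 3D centroids
coords = np.random.default_rng(0).normal(size=(100, 3))
train_idx, test_idx = spatial_split(coords)
```

Because the test fold is a contiguous spatial neighborhood rather than a random subset, a model cannot exploit smooth spatial autocorrelation to inflate its test performance.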
All models in this repository are implemented in PyTorch. Models generally follow the form Y_i,j = decode(enc(x_i), enc(x_j)).
- Encoders: PCA, PLS, low-rank projection, autoencoder, 1D convolution, transformer
- Decoders: bilinear layer, inner product, MLP task head
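The shared form Y_i,j = decode(enc(x_i), enc(x_j)) can be sketched in PyTorch with one of the encoder/decoder combinations listed above, a shared low-rank linear encoder and a bilinear decoder. This is an illustrative sketch; the class name and dimensions are assumptions, not the repo's API.

```python
import torch
import torch.nn as nn

class PairwiseConnectivityModel(nn.Module):
    """Minimal sketch of Y_ij = decode(enc(x_i), enc(x_j))."""

    def __init__(self, n_genes, d_embed=64):
        super().__init__()
        self.enc = nn.Linear(n_genes, d_embed)          # shared low-rank projection
        self.decode = nn.Bilinear(d_embed, d_embed, 1)  # bilinear decoder

    def forward(self, x_i, x_j):
        # encode both regions with the same encoder, then decode the pair
        return self.decode(self.enc(x_i), self.enc(x_j)).squeeze(-1)

model = PairwiseConnectivityModel(n_genes=1000)
x_i = torch.randn(8, 1000)  # batch of region-i expression vectors
x_j = torch.randn(8, 1000)  # batch of region-j expression vectors
y_hat = model(x_i, x_j)     # predicted connectivity per region pair
```

Swapping the encoder for PCA, an autoencoder, or a transformer, and the decoder for an inner product or MLP head, recovers the other combinations in the lists above.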
The proposed transformer-based architecture in Predicting Functional Brain Connectivity with Context-Aware Deep Neural Networks is the Spatiomolecular Transformer (SMT).
- smt.py implements the standard version of the SMT and the SMT w/ [CLS], featuring TSS-aware tokenization, a spatial-context [CLS] token, FlashAttention multi-head self-attention (MHSA), and an MLP task head.
- smt_advanced.py includes extensions to the base SMT that perform MHSA over encoded versions of the gene expression vectors, such as PCA, PLS, or autoencoded embeddings. Attention pooling is used in most of these models to compress embeddings from the transformer block.
- smt_cross.py implements compressed versions of the SMT operating at single-gene resolution with a smaller input gene list. These models are more NLP-style, seeking to learn a grammar over a learned vocabulary of select genes. They benefit from fewer parameters and may use region-to-region cross-attention in the encoding phase.
Models are optimized to minimize the mean-squared error between predictions and the target population average connectome. Pretrained models can be found in models/saved_models. An example of loading a pretrained model is in /notebooks/NeurIPS/NeurIPS_Fig5_embeddings.ipynb. See the save_model argument of single_sim_run() for saving a new model.
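The optimization objective above reduces to a standard MSE training loop over region-pair batches. The following is a hypothetical sketch, not the repo's train_val.py; the model, loader, and names are assumptions.

```python
import torch
import torch.nn as nn

def train_epoch(model, loader, optimizer):
    """One epoch minimizing MSE against the target connectome values."""
    loss_fn = nn.MSELoss()
    model.train()
    total = 0.0
    for x_i, x_j, y in loader:  # region-pair features and FC target
        optimizer.zero_grad()
        loss = loss_fn(model(x_i, x_j), y)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)

class TinyPairModel(nn.Module):
    """Toy pairwise model used only to exercise the loop."""
    def __init__(self, d):
        super().__init__()
        self.bilinear = nn.Bilinear(d, d, 1)
    def forward(self, x_i, x_j):
        return self.bilinear(x_i, x_j).squeeze(-1)

model = TinyPairModel(d=16)
loader = [(torch.randn(8, 16), torch.randn(8, 16), torch.randn(8)) for _ in range(5)]
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
avg_loss = train_epoch(model, loader, opt)
```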
- Gene expression data: The Allen Human Brain Atlas represents the most spatially resolved human gene expression dataset to date. Raw data is available here. This repo relies heavily on the abagen package for AHBA preprocessing, including normalizing, aggregating, and interpolating raw data into desired parcellations. Due to the size of the gene expression matrices, a sample csv file used for model training is made available for the coarsest parcellation resolution in /data/BHA2/iPA_183.
- Neuroimaging data: Models can be fit to connectomes from several open-source datasets. MPI-LEMON is publicly accessible here. Access to HCP can be requested here. Access to UKBB can be requested here. Population average connectomes are made available in /data. For access to underlying individualized connectomes, please reach out with the appropriate data use agreements.
GeneEx2Conn is a multi-component scientific research repository. Minimal requirements for our code are listed in /env/GeneEx2Conn.yml (see https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html for environment setup). An A100 or H100 NVIDIA GPU is recommended for training SMT models; SMT training takes less than 2 hours for all 4 folds on an A100 or H100.
```bibtex
@inproceedings{
ratzan2025predicting,
title={Predicting Functional Brain Connectivity with Context-Aware Deep Neural Networks},
author={Alexander Ratzan and Sidharth Goel and Junhao Wen and Christos Davatzikos and Erdem Varol},
booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
year={2025},
url={https://openreview.net/forum?id=iQoZv77o3g}
}
```


