neuroinfolab/GeneEx2Conn


CC BY-NC-SA 4.0

GeneEx2Conn

This repository contains code for "Predicting Functional Brain Connectivity with Context-Aware Deep Neural Networks", published at the Thirty-ninth Annual Conference on Neural Information Processing Systems (NeurIPS). It implements experiments and models for the fundamental problem of predicting region-to-region human brain functional connectivity from gene expression and spatial information.

Repository overview

The core components of the repository are listed below:

  • /sim
    • sim.py → simulation class defining features, targets, cross-validation, and run tracking
    • null.py → null spin brain map generation and evaluation
  • /data
    • data_load.py → load datasets and auxiliary data
    • data_utils.py → data utilities, core RegionPairDataset class used for training, and target augmentation helpers
    • cv_split.py → train-test split and cross-validation classes
    • BHA2/, HCP/, UKBB/ → a subset of population-average connectomes and parcellated transcriptomes
    • enigma/ → gene subsets with known biological functions, null spin test indices, and error metrics
  • /models
    • rules_based.py, bilinear.py, pls.py, dynamic_mlp.py → baseline methods
    • smt.py → FlashAttention-based multi-head self-attention (MHSA) transformer architectures, including the Spatiomolecular Transformer (SMT)
    • smt_advanced.py → SMT variants, incorporating auxiliary information and pretrained embeddings
    • smt_cross.py → compressed mixed-attention architectures (in development)
    • /configs/ → hyperparameter sweep configs for each model
    • /saved_models/ → pretrained SMT models for each dataset
    • train_val.py → global training/validation loop
    • /metrics/
      • eval.py → evaluation class with 32+ prediction metrics
  • /notebooks
    • NeurIPS/ → Jupyter notebooks for analysis and figure creation

Sim class

The main experimental entry point of GeneEx2Conn is the sim class, which lets users configure and run detailed experiments tailored to a specific research question. Below is a basic example of a sim run, which can be launched from a notebook or from the command line.

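```python
# Run from the repository root so that the sim package is importable.
# The import path is an assumption based on the file sim/sim_run.py.
from sim.sim_run import single_sim_run

single_sim_run(
    dataset='HCP',
    parcellation='S456',
    omit_subcortical=False,
    hemisphere='both',
    feature_type=[{'transcriptome': None}],  # input features
    gene_list='0.2',                         # genes to retain for transcriptome
    connectome_target='FC',                  # connectivity target
    cv_type='spatial',                       # train-test split
    random_seed=42,
    model_type='shared_transformer',         # model selection
)
```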

See sim/sim_run.py for further details. Note: All Jupyter notebooks must be run from the root directory.

Cross-validation

Careful train-test splits and null testing are critical to account for spatial autocorrelation in population-average brain maps: neighboring regions tend to have similar expression and connectivity profiles, so a naive random split can leak information from training regions into test regions. /data/cv_split.py implements four train-test split styles: random; spatial (based on Euclidean coordinates); community (based on Louvain communities); and schaefer (based on the functional subnetworks of the Schaefer parcellation). Random and spatial splits are visualized below for random_seed=42.

[Figure: four-fold random split example; four-fold spatial split example]
Training regions are shown in gray and test regions in orange. Pairwise patterns are learned in the training set to reconstruct test-set pairwise edge strength.
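To make the idea of a spatial split concrete, here is a minimal sketch that groups regions into four spatially contiguous folds with k-means on their coordinates, then builds the region pairs used for training and testing. The function names spatial_kfold and region_pairs are hypothetical, k-means is a swapped-in technique (not necessarily what cv_split.py does), and treating test pairs as pairs of two held-out regions is an assumption.

```python
import numpy as np

def spatial_kfold(coords, n_folds=4, n_iter=50, seed=42):
    """Hypothetical spatial split: k-means on region coordinates, one cluster per fold."""
    coords = np.asarray(coords, dtype=float)
    rng = np.random.default_rng(seed)
    # initialize fold centers at randomly chosen regions (fancy indexing copies)
    centers = coords[rng.choice(len(coords), n_folds, replace=False)]
    for _ in range(n_iter):
        # assign each region to its nearest center by Euclidean distance
        dists = np.linalg.norm(coords[:, None, :] - centers[None, :, :], axis=-1)
        labels = dists.argmin(axis=1)
        # move each non-empty fold's center to the mean of its regions
        for k in range(n_folds):
            if np.any(labels == k):
                centers[k] = coords[labels == k].mean(axis=0)
    return labels

def region_pairs(regions):
    """All ordered pairs (i, j) with i != j drawn from the given region indices."""
    idx = np.asarray(regions)
    i, j = np.meshgrid(idx, idx, indexing="ij")
    keep = i != j
    return np.stack([i[keep], j[keep]], axis=1)

# Usage on synthetic coordinates standing in for parcel centroids
rng = np.random.default_rng(0)
coords = rng.normal(size=(100, 3))
folds = spatial_kfold(coords)
test_regions = np.flatnonzero(folds == 0)    # fold 0 is held out
train_regions = np.flatnonzero(folds != 0)
train_pairs = region_pairs(train_regions)    # pairs the model learns from
test_pairs = region_pairs(test_regions)      # pairs whose edge strength is predicted
```

In this sketch, pairs that mix a training region with a test region are simply excluded, so no test region's features are seen during training.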

Encoder-decoder architectures

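All models in this repository are implemented in PyTorch. Models generally follow the encoder-decoder form Y_ij = decode(enc(x_i), enc(x_j)), where x_i is the input feature vector of region i (for example, its gene expression profile), the same encoder enc is applied to both regions, and Y_ij is the predicted connectivity between regions i and j.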

  • Encoders: PCA, PLS, low-rank projection, autoencoder, 1D convolution, transformer
  • Decoders: bilinear layer, inner product, MLP task head
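The encoder-decoder form can be illustrated in NumPy for its simplest case: a linear low-rank projection as the encoder and a bilinear layer as the decoder. The random projection P, the symmetric weight matrix W, and the array sizes are hypothetical choices made for illustration. They are not the repository's PyTorch modules.

```python
import numpy as np

rng = np.random.default_rng(0)

n_regions, n_genes, d = 50, 200, 8
X = rng.normal(size=(n_regions, n_genes))  # synthetic region-by-gene matrix

# Encoder: hypothetical linear low-rank projection (stand-in for PCA, PLS, etc.)
P = rng.normal(size=(n_genes, d)) / np.sqrt(n_genes)

def enc(x):
    return x @ P

# Decoder: bilinear form z_i^T W z_j with a symmetric weight matrix W
A = rng.normal(size=(d, d))
W = (A + A.T) / 2

def decode(z_i, z_j):
    return z_i @ W @ z_j

# Prediction for a single region pair (0, 1)
y_01 = decode(enc(X[0]), enc(X[1]))

# Vectorized over all pairs: Y_hat = Z W Z^T with Z = enc(X)
Z = enc(X)
Y_hat = Z @ W @ Z.T
```

Because the same encoder is shared across both regions and W is symmetric, the prediction for (i, j) equals the prediction for (j, i), which suits an undirected functional connectome. The vectorized form scores every region pair with a few matrix products.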

The transformer-based architecture proposed in "Predicting Functional Brain Connectivity with Context-Aware Deep Neural Networks" is the Spatiomolecular Transformer (SMT).

  • smt.py implements the standard SMT and the SMT with a [CLS] token. It features TSS-aware tokenization, a spatial-context [CLS] token, FlashAttention multi-head self-attention (MHSA), and an MLP task head.
  • smt_advanced.py extends the base SMT by performing MHSA over encoded versions of the gene expression vectors, such as PCA, PLS, or autoencoder embeddings. Most of these models use attention pooling to compress the embeddings produced by the transformer block.
  • smt_cross.py implements compressed versions of the SMT that operate at single-gene resolution on a smaller input gene list. These models are closer to NLP-style language models and seek to learn a "grammar" over a learned vocabulary of selected genes. They have fewer parameters and may use region-to-region cross-attention in the encoding phase.
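The attention pooling mentioned for smt_advanced.py can be summarized as follows: a learned query scores each token output by the transformer block, and a softmax over those scores weights the tokens into one vector. This single-query NumPy version is an assumption made for illustration. The repository's PyTorch layers may use multiple heads, extra projections, or a different scoring function, and the function name attention_pool is hypothetical.

```python
import numpy as np

def softmax(s):
    s = s - s.max()  # subtract the max for numerical stability
    e = np.exp(s)
    return e / e.sum()

def attention_pool(H, q):
    """Pool transformer outputs H (n_tokens, d) into one d-dim vector.

    q is a learned query vector of shape (d,); it is random here for illustration.
    """
    scores = H @ q / np.sqrt(H.shape[1])  # scaled dot-product score per token
    weights = softmax(scores)             # one nonnegative weight per token, summing to 1
    return weights @ H, weights

rng = np.random.default_rng(0)
H = rng.normal(size=(16, 32))  # e.g. 16 tokens of width 32 for one region
q = rng.normal(size=32)
pooled, w = attention_pool(H, q)
```

With a zero query, every token receives equal weight and attention pooling reduces to mean pooling. Learning the query instead lets the model emphasize the most informative tokens.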


Models are optimized to minimize the mean-squared error between their predictions and the target population-average connectome. Pretrained models can be found in models/saved_models. An example of loading a pretrained model is given in /notebooks/NeurIPS/NeurIPS_Fig5_embeddings.ipynb. To save a newly trained model, see the save_model argument of single_sim_run().
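Consider the simplest model in this family, an identity encoder with a bilinear decoder. Here, minimizing the mean-squared error over region pairs is an ordinary least-squares problem, because each prediction x_i^T W x_j is linear in the entries of W. The sketch below fits W in closed form on synthetic, noiseless data. It is a hypothetical illustration, not the repository's PyTorch training loop, which optimizes deeper models such as the SMT.

```python
import numpy as np

rng = np.random.default_rng(0)
n, g = 20, 3
X = rng.normal(size=(n, g))       # synthetic region features
W_true = rng.normal(size=(g, g))  # unknown bilinear weights
Y = X @ W_true @ X.T              # synthetic noiseless "connectome"

# Training pairs: all ordered pairs i != j (the diagonal is ignored)
i, j = np.where(~np.eye(n, dtype=bool))

# Each pair gives one linear equation: y_ij = np.kron(x_i, x_j) @ W.ravel()
A = np.einsum("pa,pb->pab", X[i], X[j]).reshape(len(i), g * g)
b = Y[i, j]

# Minimizing the mean-squared error over these pairs is ordinary least squares
w_hat, *_ = np.linalg.lstsq(A, b, rcond=None)
W_hat = w_hat.reshape(g, g)

mse = np.mean((A @ w_hat - b) ** 2)
```

On real connectomes the fit would not be exact, and the remaining MSE on held-out pairs is what the split strategies above are designed to estimate fairly.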

Datasets & access

  • Gene expression data: The Allen Human Brain Atlas (AHBA) is the most spatially resolved human gene expression dataset to date. Raw data is available here. This repo relies heavily on the abagen package for AHBA preprocessing, including normalizing, aggregating, and interpolating raw data into the desired parcellations. Because the gene expression matrices are large, a sample CSV file used for model training is provided only for the coarsest parcellation resolution, in /data/BHA2/iPA_183.
  • Neuroimaging data: Models can be fit to connectomes from several datasets, some publicly available and some access-controlled. MPI-LEMON is publicly accessible here. Access to HCP can be requested here. Access to UKBB can be requested here. Population-average connectomes are made available in /data. For access to the underlying individual-level connectomes, please reach out with the appropriate data use agreements.

Environment setup

GeneEx2Conn is a multi-component scientific research repository. Minimal requirements are listed in /env/GeneEx2Conn.yml. From the repository root, the environment can be created with conda env create -f env/GeneEx2Conn.yml (see https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html for details on environment setup). Access to an NVIDIA A100 or H100 GPU is recommended for training SMT models. On these GPUs, SMT training takes less than 2 hours for all four folds.

Citation

@inproceedings{
  ratzan2025predicting,
  title={Predicting Functional Brain Connectivity with Context-Aware Deep Neural Networks},
  author={Alexander Ratzan and Sidharth Goel and Junhao Wen and Christos Davatzikos and Erdem Varol},
  booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
  year={2025},
  url={https://openreview.net/forum?id=iQoZv77o3g}
}

About

Linking the human transcriptome and connectome
