Molecular active learning with JAX - a lightweight framework for active learning in molecular property prediction.
- 🧠 Graph neural networks implemented in JAX/Flax for molecular representation learning
- 🔄 Complete active learning workflow for molecular property prediction
- 🎯 Multiple acquisition functions for diverse exploration strategies
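To show what a graph-convolution layer computes, here is a minimal sketch of the standard GCN propagation rule of Kipf and Welling, H' = ReLU(D^-1/2 (A + I) D^-1/2 H W), where D is the degree matrix of A + I. It is written in plain JAX and is not molax's implementation. The helpers normalized_adjacency, gcn_layer and readout are hypothetical. We also assume a molecule is represented as a node-feature matrix plus a dense adjacency matrix.

```python
import jax
import jax.numpy as jnp


def normalized_adjacency(adj: jnp.ndarray) -> jnp.ndarray:
    """Symmetric normalization D^-1/2 (A + I) D^-1/2, adding self-loops first."""
    a_hat = adj + jnp.eye(adj.shape[0], dtype=adj.dtype)
    deg = a_hat.sum(axis=1)                      # node degrees including self-loop
    d_inv_sqrt = 1.0 / jnp.sqrt(deg)
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def gcn_layer(node_feats, adj, weight, bias):
    """One graph convolution: ReLU(A_norm @ X @ W + b)."""
    return jax.nn.relu(normalized_adjacency(adj) @ node_feats @ weight + bias)


def readout(node_feats):
    """Mean-pool node embeddings into a single graph-level vector."""
    return node_feats.mean(axis=0)


# Toy molecule-like graph: 3 atoms in a chain, 4 features per atom (hypothetical).
adj = jnp.array([[0., 1., 0.],
                 [1., 0., 1.],
                 [0., 1., 0.]])
x = jnp.ones((3, 4))
key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (4, 8)) * 0.1
b = jnp.zeros(8)

h = gcn_layer(x, adj, w, b)          # per-atom embeddings
graph_embedding = readout(h)         # one vector per molecule
print(h.shape, graph_embedding.shape)  # (3, 8) (8,)
```

Stacking several such layers and feeding the pooled vector to a small output head gives a molecule-level property predictor. Mean pooling is only one readout choice; sum or attention pooling are common alternatives.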
```bash
git clone https://github.com/HFooladi/molax
cd molax
pip install -r requirements.txt
```

Required dependencies:
- jax
- flax
- optax
- rdkit
- pandas
- numpy
Basic usage with SMILES data:
```python
import flax.nnx as nnx
import optax

from molax.utils.data import MolecularDataset
from molax.models.gcn import UncertaintyGCN, UncertaintyGCNConfig

# Load your data
dataset = MolecularDataset('datasets/molecules.csv')

# Split dataset
train_data, test_data = dataset.split_train_test(test_size=0.2)

# Initialize model
config = UncertaintyGCNConfig(
    in_features=train_data.graphs[0][0].shape[1],  # input feature size, read from the first training graph
    hidden_features=[64, 64],
    out_features=1,
    dropout_rate=0.1,
)
model = UncertaintyGCN(config)

# Initialize optimizer
model_and_opt = nnx.ModelAndOptimizer(model, optax.adam(1e-3))

# Run active learning loop
# See examples/simple_active_learning.py for the complete implementation
```

- Graph neural networks implemented in JAX/Flax
- Uncertainty estimation via MC dropout
- Multiple acquisition functions
- Efficient batch selection
- RDKit-based molecular processing
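The MC-dropout and acquisition features can be illustrated with a small, self-contained sketch. MC dropout keeps dropout active at prediction time, runs several stochastic forward passes, and treats the spread of the predictions as uncertainty. An acquisition function then turns the mean and spread into a score, and a batch consists of the highest-scoring pool members. The sketch uses a toy two-layer MLP in plain JAX in place of UncertaintyGCN. Every name in it (mc_dropout_predict, acquire_variance, acquire_ucb, select_batch, beta) is a hypothetical illustration, not molax's API. Only dropout_rate=0.1 mirrors the configuration above.

```python
import jax
import jax.numpy as jnp


def mlp_forward(params, x, key, dropout_rate=0.1):
    """Two-layer MLP whose dropout stays ON at inference (MC dropout)."""
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    keep = jax.random.bernoulli(key, 1.0 - dropout_rate, h.shape)
    h = jnp.where(keep, h / (1.0 - dropout_rate), 0.0)  # inverted dropout
    return (h @ params["w2"] + params["b2"]).squeeze(-1)


def mc_dropout_predict(params, x, key, n_samples=50, dropout_rate=0.1):
    """Predictive mean and std over n_samples stochastic forward passes."""
    keys = jax.random.split(key, n_samples)
    preds = jax.vmap(lambda k: mlp_forward(params, x, k, dropout_rate))(keys)
    return preds.mean(axis=0), preds.std(axis=0)


def acquire_variance(mean, std):
    """Pure exploration: prefer the most uncertain candidates."""
    return std ** 2


def acquire_ucb(mean, std, beta=2.0):
    """Upper confidence bound: trade off high predictions against uncertainty."""
    return mean + beta * std


def select_batch(scores, batch_size):
    """Indices of the batch_size highest-scoring pool molecules (greedy top-k)."""
    return jax.lax.top_k(scores, batch_size)[1]


# Toy setup: 100 pool "molecules", each with 16 precomputed features (hypothetical).
k_params, k_pool, k_mc = jax.random.split(jax.random.PRNGKey(0), 3)
k1, k2 = jax.random.split(k_params)
params = {
    "w1": jax.random.normal(k1, (16, 32)) * 0.1, "b1": jnp.zeros(32),
    "w2": jax.random.normal(k2, (32, 1)) * 0.1, "b2": jnp.zeros(1),
}
pool = jax.random.normal(k_pool, (100, 16))

mean, std = mc_dropout_predict(params, pool, k_mc)
batch = select_batch(acquire_ucb(mean, std), batch_size=8)
print(batch.shape)  # (8,)
```

Greedy top-k is the simplest batch strategy. It can pick near-duplicate molecules, which is why diversity-aware batch methods exist. This sketch does not claim to reproduce how molax implements its batch selection.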
Check examples/simple_active_learning.py for a complete active learning pipeline with uncertainty-based acquisition.
For a demonstration of uncertainty quantification, see examples/uncertainty_gcn_demo.py.
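Schematically, a pool-based active learning campaign repeats three steps: fit a model on the labeled molecules, score the unlabeled pool, and query labels for the top-scoring batch. The sketch below runs that cycle in plain JAX. To stay short and free of a training loop, it swaps the GCN for a bootstrap ensemble of ridge regressors and uses ensemble spread in place of MC-dropout variance. It also works on synthetic feature vectors rather than SMILES. All names and defaults (fit_ensemble, active_learning_loop, n_init, batch_size, n_rounds) are hypothetical. The example scripts remain the reference for the project's actual pipeline.

```python
import jax
import jax.numpy as jnp


def fit_ridge(x, y, lam=1e-2):
    """Closed-form ridge regression weights."""
    d = x.shape[1]
    return jnp.linalg.solve(x.T @ x + lam * jnp.eye(d), x.T @ y)


def fit_ensemble(key, x, y, n_models=8):
    """Bootstrap ensemble: each member is fit on a resampled labeled set."""
    keys = jax.random.split(key, n_models)

    def fit_one(k):
        idx = jax.random.randint(k, (x.shape[0],), 0, x.shape[0])
        return fit_ridge(x[idx], y[idx])

    return jax.vmap(fit_one)(keys)  # shape (n_models, d)


def predict(weights, x):
    """Ensemble mean and std for each row of x."""
    preds = x @ weights.T  # shape (n, n_models)
    return preds.mean(axis=1), preds.std(axis=1)


def active_learning_loop(key, x_pool, y_pool, n_init=10, batch_size=5, n_rounds=4):
    """Fit, score the unlabeled pool, and query the most uncertain batch each round."""
    n = x_pool.shape[0]
    key, k_init = jax.random.split(key)
    labeled = jax.random.permutation(k_init, n)[:n_init].tolist()
    for _ in range(n_rounds):
        idx = jnp.array(labeled)
        key, k_fit = jax.random.split(key)
        weights = fit_ensemble(k_fit, x_pool[idx], y_pool[idx])
        unlabeled = jnp.array(sorted(set(range(n)) - set(labeled)))
        _, std = predict(weights, x_pool[unlabeled])
        picks = jax.lax.top_k(std, batch_size)[1]
        # "Querying the oracle" here just reveals labels already in y_pool.
        labeled += unlabeled[picks].tolist()
    return labeled


key = jax.random.PRNGKey(42)
k_x, k_w, k_loop = jax.random.split(key, 3)
x_pool = jax.random.normal(k_x, (200, 6))
y_pool = x_pool @ jax.random.normal(k_w, (6,))  # synthetic "property"
labeled = active_learning_loop(k_loop, x_pool, y_pool)
print(len(labeled))  # 30
```

With n_init=10, batch_size=5 and n_rounds=4, the loop ends with 10 + 4 × 5 = 30 labeled points. In a real campaign the labels for each queried batch would come from experiments or simulations rather than a precomputed array.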
- Fork the repository
- Create your feature branch (git checkout -b feature/amazing-feature)
- Commit changes (git commit -m 'Add amazing feature')
- Push to branch (git push origin feature/amazing-feature)
- Open a Pull Request
```bibtex
@software{molax2025,
  title={molax: Molecular Active Learning with JAX},
  author={Hosein Fooladi},
  year={2025},
  url={https://github.com/hfooladi/molax},
  description={A lightweight framework for active learning in molecular property prediction}
}
```

MIT License
- This project builds upon the excellent JAX, Flax, and RDKit libraries.
- Thanks to all contributors who have helped improve this project.