ProtRL: Reinforcement Learning for Protein Language Models

A Reinforcement Learning (RL) framework for autoregressive protein Language Models (pLMs). Currently we have implemented the following algorithms:

  • Weighted DPO
  • GRPO (bnpo, dr_grpo, and grpo loss variants)

This is the repository for the paper Guiding Generative Protein Language Models with Reinforcement Learning.


About ProtRL

ProtRL builds on the GRPO implementation in Hugging Face’s TRL library and extends the trainer to support:

  1. Passing custom datasets at each iteration
  2. A weighted variant of DPO (not available in the standard Hugging Face trainers)

Quickstart Example

from src.utils import *
from src.pLM_weigtedDPO import weighted_DPO
from src.pLM_GRPO import pLM_GRPOTrainer
from trl import GRPOConfig

training_args = GRPOConfig(output_dir="ZymCTRL-wDPO", logging_steps=10)

# reward_len, train_dataset, eval_dataset and tokenizer are defined as in the Usage section below.
trainer = weighted_DPO(  # or pLM_GRPOTrainer
    model="AI4PD/ZymCTRL",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
)

trainer.train()

Usage

The trainers accept datasets in the standard Hugging Face format, for example:

{"prompt": "The sky is", "completion": " blue.", "advantage":10}

Offline training

Use train_exp.py, which expects a CSV file with columns:

  • prompt: the prompt, if any (for conditional generation)
  • sequence: pre-formatted protein sequences
  • advantage: numerical weight for each sequence
python train_exp.py --model_dir "AI4PD/ZymCTRL" --csv "training_data.csv"

The script builds the dataset for you and trains your model.
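
For reference, a minimal training_data.csv could look like this (hypothetical sequences and weights; the prompt column is left empty for unconditional generation):

prompt,sequence,advantage
,MTEYKLVVVGAGGVGKSA,1.0
,MKQLEDKVEELLSKNYHL,-0.5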

Online training

  1. For straightforward rewards (e.g., sequence length or amino-acid ratios), we recommend the standard GRPO trainer from Hugging Face's TRL:
from datasets import load_dataset
from transformers import AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("your_dataset", split="train")
split = dataset.train_test_split(test_size=0.80, seed=42, shuffle=True)

train_dataset = split["train"]
eval_dataset = split["test"]

# Reward function: in this case, favour completions close to 20 residues.
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

tokenizer = AutoTokenizer.from_pretrained("AI4PD/ZymCTRL")
tokenizer.padding_side = "left"
tokenizer.eos_token_id = 1
tokenizer.pad_token_id = 0

training_args = GRPOConfig(output_dir="ZymCTRL-GRPO", logging_steps=10)

trainer = GRPOTrainer(
    model="AI4PD/ZymCTRL",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
)

trainer.train()
trainer.save_model()

  2. For complex pipelines, where you explicitly generate, save, and externally score sequences at each iteration, you can use our trainers. This is ideal when sequences are scored on CPU arrays before training on the GPU:

from src.utils import *
from src.pLM_weigtedDPO import weighted_DPO
from trl import GRPOConfig

training_args = GRPOConfig(output_dir="ZymCTRL-GRPO", logging_steps=10)

trainer = weighted_DPO(  # or pLM_GRPOTrainer
    model="AI4PD/ZymCTRL",
    reward_funcs=reward_len,  # ignored by weighted_DPO; see the note below
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
)

trainer.train()

Note: reward_funcs is ignored by the weighted DPO trainer and can be set to a function that always returns 0; see the examples.
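
A placeholder reward function can be as simple as the following sketch:

def reward_dummy(completions, **kwargs):
    # The weighted DPO trainer takes its weights from the "advantage"
    # column of the dataset, so this reward is never used.
    return [0.0 for _ in completions]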

For the original DPO algorithm, we recommend the Hugging Face DPO Trainer.

Weighted DPO loss functions were adapted from those first described in Widatalla et al., 2024. Detailed explanations of each loss function and its changes in formulation can be found in the Methods section of the paper.

Note: Weights and advantages are treated as "the higher, the better." If your scoring function is designed to be minimized, please multiply it by -1.
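
For example, with a score meant to be minimized (a hypothetical energy column in a pandas DataFrame):

# Hypothetical: lower energy is better, so negate it to obtain an
# advantage where higher is better.
df["advantage"] = -df["energy"]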

Installation

git clone https://github.com/AI4PDLab/ProtRL.git
cd ProtRL
pip install -r requirements.txt

Example

TinyLLaMA Length Reduction

The example directory includes a tiny-llama subdirectory, which demonstrates decreasing sequence length towards 50 amino acids using a TinyLLaMA model that can be run locally on a single GPU.

cd example/GRPO
bash ProtRL-local.sh

This generates a TinyLLaMA model, runs RL training, and plots length reduction over iterations.

Carbonic Anhydrase Fold in ZymCTRL

We also provide a more complex example in example/ZymCTRL-fold, where the fold of carbonic anhydrase is progressively adapted over RL iterations. This example requires ESMFold and an 80 GB GPU.

Experiments

To reproduce the experiments from our paper, all the scripts can be found in the experiments folder. Given the size and computational needs of pLMs, each experiment was run on a single H100 GPU, with varying execution times. All the parameters and external data used in the experiments can be found in this repo. The .sh scripts can be executed from the same folder to run each experiment; they are built for a SLURM-based cluster, given the need for GPU-intensive computing. To reproduce the results, run:

bash experiment_name.sh

or

sbatch experiment_name.sh

Replace experiment_name with the desired experiment script path. Each experiment will generate sequences, fold them, and compute statistics for each considered feature.

Notes

seq_gen.py in the main directory generates a FASTA file with the format: >fasta_name \t perplexity \t intrinsic_reward \n sequence
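
For illustration, a single record could look like this (hypothetical values):

>seq_0	2.31	0.87
MTEYKLVVVGAGGVGKSALTIQLIQ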

We have discontinued ranked DPO, as in theory it will always be outperformed by weighted DPO.

Troubleshooting

Please take a look at the documentation for more details on how to configure and run your experiments.

Feel free to contribute or raise issues if you encounter any problems! We are working to make the framework more accessible and the documentation more detailed.

Work in Progress

[ ] LoRA example

References

Citation

If you use ProtRL, please cite our preprint:

@misc{stocco2024guidinggenerativeproteinlanguage,
      title={Guiding Generative Protein Language Models with Reinforcement Learning}, 
      author={Filippo Stocco and Maria Artigues-Lleixa and Andrea Hunklinger and Talal Widatalla and Marc Guell and Noelia Ferruz},
      year={2024},
      eprint={2412.12979},
      archivePrefix={arXiv},
      primaryClass={q-bio.BM},
      url={https://arxiv.org/abs/2412.12979}, 
}
