
Usage

This section covers the full training, inference, and evaluation workflows for the main GCP-VQVAE pipeline (multi-GPU supported via Accelerate), using HDF5 datasets generated by the data preprocessing pipeline.

Before you begin, prepare your dataset in .h5 format as described in Data Pipeline.

Training

Configure your training parameters in configs/config_vqvae.yaml and run:


# Set up accelerator configuration for multi-GPU training
accelerate config

# Start training with accelerate for multi-GPU support
accelerate launch train.py --config_path configs/config_vqvae.yaml

See the Accelerate documentation for more options and configurations.

Inference

Pretrained Models

| Model | Description | Download Link |
|-------|-------------|---------------|
| Large | Full GCP-VQVAE model with best performance | Download |
| Lite | Lightweight version for faster inference | Download |

Setup Instructions:

  1. Download the zip file of the checkpoint
  2. Extract the checkpoint folder
  3. Set the trained_model_dir path in the relevant config files (listed below) to point to the extracted checkpoint folder.

Multi-GPU with Hugging Face Accelerate:

  • The following scripts support multi-GPU via Accelerate: inference_encode.py, inference_embed.py, inference_decode.py, and evaluation.py.

Example (2 GPUs, bfloat16):

accelerate launch --multi_gpu --mixed_precision=bf16 --num_processes=2 evaluation.py

Alternatively, configure Accelerate first, as in Training:

accelerate config
accelerate launch evaluation.py

See the Accelerate documentation for more options and configurations.

All inference scripts consume .h5 inputs in the format defined in Data.

To extract the VQ codebook embeddings:

python codebook_extraction.py

Edit configs/inference_codebook_extraction_config.yaml to change paths and output filename.
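Once extracted, the codebook is just a table of K embedding vectors. The sketch below shows a typical use: finding the nearest codebook entry for an arbitrary embedding. The in-memory layout (a plain list of vectors) and the toy K=3 codebook are illustrative assumptions; check the actual output of codebook_extraction.py for the real format.

```python
# Toy codebook: K=3 codes of dimension 2. The real codebook is K x embed_dim
# as configured in config_vqvae.yaml (e.g. K=4096).
codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]

def nearest_code(vec, codebook):
    """Return the index of the codebook vector closest in Euclidean distance."""
    def dist2(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(range(len(codebook)), key=lambda k: dist2(vec, codebook[k]))

print(nearest_code([0.9, 0.1], codebook))  # closest to code 1
```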

To encode proteins into discrete VQ indices:

python inference_encode.py

Edit configs/inference_encode_config.yaml to change dataset paths, model, and output. Input datasets must be .h5 files in the HDF5 format used by this repo.
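The encoded output is a CSV of discrete VQ indices (cf. vq_indices_csv_filename in the evaluation config). A minimal parsing sketch follows; the row layout assumed here (one protein per row: an ID followed by its per-residue code indices) is an illustrative guess, not the repo's documented schema.

```python
import csv
import io

# Toy stand-in for a vq_indices.csv produced by inference_encode.py.
# Assumed layout: protein_id, idx_1, idx_2, ..., idx_L (one row per protein).
sample = "prot_A,12,507,12,3990\nprot_B,44,44,871\n"

indices_per_protein = {}
for row in csv.reader(io.StringIO(sample)):
    indices_per_protein[row[0]] = [int(tok) for tok in row[1:]]

print(indices_per_protein["prot_A"])  # [12, 507, 12, 3990]
```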

To extract per-residue embeddings from the VQ layer:

python inference_embed.py

Edit configs/inference_embed_config.yaml to change dataset paths, model, and output HDF5. Input .h5 files must follow the HDF5 format used by this repo.
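The per-residue embeddings are written to an HDF5 file. The sketch below shows reading such a file with h5py; the group/dataset layout assumed here (one "<protein_id>/embeddings" dataset of shape (num_residues, embed_dim) per protein) is an assumption for illustration, not the repo's documented schema.

```python
import h5py
import numpy as np

# Write a toy file with the hypothetical layout, then read it back.
with h5py.File("embeddings_demo.h5", "w") as f:
    f.create_dataset("prot_A/embeddings",
                     data=np.zeros((10, 64), dtype=np.float32))

with h5py.File("embeddings_demo.h5", "r") as f:
    emb = f["prot_A/embeddings"][:]  # NumPy array, shape (10, 64)

print(emb.shape)
```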

To decode VQ indices back to 3D backbone structures:

python inference_decode.py

Edit configs/inference_decode_config.yaml to point to the indices CSV and adjust runtime settings. To write PDBs from decoded outputs, see Convert HDF5 → PDB.
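Decoded backbones ultimately end up as PDB ATOM records in the Convert HDF5 → PDB step. As a sketch of what that conversion produces, the helper below formats one CA atom in the fixed-column PDB v3.3 layout; the hardcoded GLY residue name and the function itself are illustrative, not part of the repo.

```python
def ca_atom_line(serial, resseq, x, y, z, chain="A"):
    """Format one CA ATOM record in PDB v3.3 fixed columns (78 chars).

    Residue name is hardcoded to GLY for this toy example; a real converter
    would carry the sequence through from the input data.
    """
    return (f"ATOM  {serial:>5}  CA  GLY {chain}{resseq:>4}    "
            f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00           C")

line = ca_atom_line(1, 1, 1.0, 2.0, 3.0)
print(line)
```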

Evaluation

To evaluate predictions and write TM-score/RMSD along with aligned PDBs:

python evaluation.py


Example config template (configs/evaluation_config.yaml):

trained_model_dir: "/abs/path/to/trained_model"   # Folder containing checkpoint and saved YAMLs
checkpoint_path: "checkpoints/best_valid.pth"     # Relative to trained_model_dir
config_vqvae: "config_vqvae.yaml"                 # Names of saved training YAMLs
config_encoder: "config_gcpnet_encoder.yaml"
config_decoder: "config_geometric_decoder.yaml"

data_path: "/abs/path/to/evaluation/data.h5"      # HDF5 used for evaluation
output_base_dir: "evaluation_results"              # A timestamped subdir is created inside

batch_size: 8
shuffle: true
num_workers: 0
max_task_samples: 5000000                           # Optional cap
vq_indices_csv_filename: "vq_indices.csv"          # Also writes observed VQ indices
alignment_strategy: "kabsch"                       # "kabsch" or "no"
mixed_precision: "bf16"                            # "no", "fp16", "bf16", "fp8"

tqdm_progress_bar: true
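evaluation.py reports TM-score and RMSD after superimposing structures (alignment_strategy: "kabsch" above). A minimal sketch of the RMSD metric itself, assuming the two coordinate sets are already aligned:

```python
import math

def rmsd(coords_a, coords_b):
    """Root-mean-square deviation between two equal-length lists of (x, y, z).

    Assumes the structures are already superimposed; in evaluation.py the
    Kabsch alignment handles that step first.
    """
    n = len(coords_a)
    sq = sum((ax - bx) ** 2 + (ay - by) ** 2 + (az - bz) ** 2
             for (ax, ay, az), (bx, by, bz) in zip(coords_a, coords_b))
    return math.sqrt(sq / n)

a = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
b = [(0.0, 0.0, 1.0), (1.0, 0.0, 1.0)]  # a shifted by 1 A along z
print(rmsd(a, b))  # 1.0
```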

Codebook Usage Statistics

Enable model.vqvae.vector_quantization.log_codebook_usage_statistics: true to log the following statistics to TensorBoard under codebook_usage_statistics/* during validation:

| Metric | TensorBoard tag | Range |
|--------|-----------------|-------|
| Unigram entropy (bits) | codebook_usage_statistics/entropy_unigram_bits | 0 → log2(K) (12 bits for K=4096) |
| Unigram perplexity | codebook_usage_statistics/perplexity_unigram | 1 → K |
| Bigram conditional entropy H2\|1 | codebook_usage_statistics/entropy_bigram_cond_bits | 0 → log2(K) |
| Bigram perplexity | codebook_usage_statistics/perplexity_bigram | 1 → K |
| Trigram conditional entropy H3\|21 | codebook_usage_statistics/entropy_trigram_cond_bits | 0 → log2(K) |
| Trigram perplexity | codebook_usage_statistics/perplexity_trigram | 1 → K |
| Delta H1 = H1 - H2\|1 | codebook_usage_statistics/delta_entropy_h1_h2 | 0 → H1 |
| Delta H2 = H1 - H3\|21 | codebook_usage_statistics/delta_entropy_h1_h3 | 0 → H1 |
| Extra conditional gain H2\|1 - H3\|21 | codebook_usage_statistics/delta_entropy_conditional | 0 → H2\|1 |
| Mutual information (lag d) | codebook_usage_statistics/mutual_info_lag{d} | ≥0, typically ≤2 bits |
| Zipf slope | codebook_usage_statistics/zipf_slope | negative (approx -0.5 to -1.2) |
| Zipf R2 | codebook_usage_statistics/zipf_r2 | 0 → 1 |
| Active codes | codebook_usage_statistics/active_codes | 0 → K |
| Effective usage ratio (PPL1/active) | codebook_usage_statistics/effective_usage_ratio | 0 → 1 |
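The unigram metrics above are simple functions of the empirical code distribution. A minimal sketch of entropy_unigram_bits, perplexity_unigram, active_codes, and effective_usage_ratio on a toy index sequence (the actual logging runs inside validation, not via this helper):

```python
import math
from collections import Counter

def unigram_stats(indices, K):
    """Unigram codebook-usage statistics over a sequence of VQ indices."""
    counts = Counter(indices)
    n = len(indices)
    probs = [c / n for c in counts.values()]
    entropy_bits = -sum(p * math.log2(p) for p in probs)  # 0 .. log2(K)
    perplexity = 2.0 ** entropy_bits                      # 1 .. K
    return {
        "entropy_unigram_bits": entropy_bits,
        "perplexity_unigram": perplexity,
        "active_codes": len(counts),                      # 0 .. K
        "effective_usage_ratio": perplexity / len(counts),  # PPL1 / active
    }

# Uniform usage of 4 codes out of K=4096: entropy = 2 bits, perplexity = 4,
# effective usage ratio = 1.0.
stats = unigram_stats([0, 1, 2, 3] * 25, K=4096)
print(stats)
```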