diff --git a/.gitignore b/.gitignore index 7e9e68e5..98daa2d7 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__/ lightning_logs .DS_Store +**/.DS_Store .vscode/ .env @@ -26,4 +27,12 @@ wandb/ /store /inference /ProteinMPNN/ca_model_weights/*.pt -/ProteinMPNN/vanilla_model_weights/*.pt \ No newline at end of file +/ProteinMPNN/vanilla_model_weights/*.pt + +# Experiment logs +experiments.csv + +# Terraform +infra/terraform/.terraform/ +infra/terraform/terraform.tfstate* +infra/terraform/*.tfvars diff --git a/configs/inference_base.yaml b/configs/inference_base.yaml index 997e07ef..a1bf2c73 100644 --- a/configs/inference_base.yaml +++ b/configs/inference_base.yaml @@ -17,6 +17,7 @@ generation: args: nsteps: 400 self_cond: True + solver: euler # "euler", "heun" (ODE), "adaptive_heun" (ODE adaptive), or "stochastic_heun" (SDE 2nd order) # Guidance guidance_w: 1.0 # guidance model weights, 1.0 for w/o CFG and autoguidance, 0.0 for excluding the main model @@ -25,6 +26,15 @@ generation: save_trajectory_every: 0 # at which step interval to save trajectory snapshots of generation, 0 for no saving + # Adaptive step sizing (used when solver: adaptive_heun) + adaptive: + atol: 1.0e-3 + rtol: 1.0e-2 + dt_init: 0.01 + dt_min: 1.0e-4 + dt_max: 0.1 + safety_factor: 0.9 + # Model-specific sampling arguments model: bb_ca: diff --git a/configs/inference_ucond_notri_adaptive.yaml b/configs/inference_ucond_notri_adaptive.yaml new file mode 100644 index 00000000..7082face --- /dev/null +++ b/configs/inference_ucond_notri_adaptive.yaml @@ -0,0 +1,9 @@ +defaults: + - inference_ucond_notri_ode + - _self_ + +run_name_: laproteina_ucond_notri_adaptive +generation: + args: + solver: adaptive_heun + nsteps: 100 diff --git a/configs/inference_ucond_notri_heun.yaml b/configs/inference_ucond_notri_heun.yaml new file mode 100644 index 00000000..3dae7876 --- /dev/null +++ b/configs/inference_ucond_notri_heun.yaml @@ -0,0 +1,9 @@ +defaults: + - inference_ucond_notri_ode + - _self_ + +run_name_: laproteina_ucond_notri_heun +generation: + args: + solver: heun + nsteps: 100 diff --git a/configs/inference_ucond_notri_ode.yaml b/configs/inference_ucond_notri_ode.yaml new file mode 100644 index 00000000..4b9cc365 --- /dev/null +++ b/configs/inference_ucond_notri_ode.yaml @@ -0,0 +1,13 @@ +defaults: + - inference_ucond_notri + - _self_ + +run_name_: laproteina_ucond_notri_ode +generation: + model: + bb_ca: + simulation_step_params: + sampling_mode: vf + local_latents: + simulation_step_params: + sampling_mode: vf diff --git a/configs/inference_ucond_notri_stochastic_heun.yaml b/configs/inference_ucond_notri_stochastic_heun.yaml new file mode 100644 index 00000000..51add0ff --- /dev/null +++ b/configs/inference_ucond_notri_stochastic_heun.yaml @@ -0,0 +1,9 @@ +defaults: + - inference_ucond_notri + - _self_ + +run_name_: laproteina_ucond_notri_stochastic_heun +generation: + args: + solver: stochastic_heun + nsteps: 100 diff --git a/proteinfoundation/eval_trajectory.py b/proteinfoundation/eval_trajectory.py new file mode 100644 index 00000000..91aa3a09 --- /dev/null +++ b/proteinfoundation/eval_trajectory.py @@ -0,0 +1,265 @@ +"""Evaluate trajectory snapshots from protein generation. + +Loads ESMFold once and evaluates all trajectory steps + final outputs. +Produces a single CSV with a 'step' column and prints a summary table. + +Usage: + python proteinfoundation/eval_trajectory.py \ + --config_name inference_ucond_notri \ + --input_dir ./inference/inference_ucond_notri +""" + +import argparse +import glob +import os +import re +import sys +from collections import defaultdict + + +def _numsort_key(path: str): + """Sort key that orders numeric segments numerically, not alphabetically.""" + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", path)] + +root = os.path.abspath(".") +sys.path.insert(0, root) +# isort: split + +import pandas as pd +import torch +from loguru import logger + +from proteinfoundation.metrics.designability import ( + extract_seq_from_pdb, + pdb_name_from_path, + scRMSD, +) +from proteinfoundation.metrics.folding_models import load_esmfold + + +def discover_steps(input_dir: str) -> list[dict]: + """Discover trajectory step PDBs and the final output PDBs. + + Supports per-sample trajectory directories (new layout): + job_0_n_100_trajectory_0/job_0_step_0002_n_100_sample_0.pdb + and legacy per-step directories: + trajectory_step_0002/step_0002_sample_0.pdb + + Returns a list of dicts: [{"step": int, "pdb_paths": [str, ...], "label": str}, ...] + Sorted by step number. Final outputs get label "final". + """ + # Collect trajectory PDBs grouped by step number. + step_pdbs: dict[int, list[str]] = defaultdict(list) + + # New layout: per-sample dirs named *_trajectory_* + for d in sorted(glob.glob(os.path.join(input_dir, "*_trajectory_*")), key=_numsort_key): + if not os.path.isdir(d): + continue + for pdb in sorted(glob.glob(os.path.join(d, "*.pdb")), key=_numsort_key): + match = re.search(r"_step_(\d+)_", os.path.basename(pdb)) + if match: + step_pdbs[int(match.group(1))].append(pdb) + + # Legacy layout: per-step dirs named trajectory_step_NNNN + for d in sorted(glob.glob(os.path.join(input_dir, "trajectory_step_*")), key=_numsort_key): + if not os.path.isdir(d): + continue + match = re.search(r"trajectory_step_(\d+)$", d) + if not match: + continue + step = int(match.group(1)) + for pdb in sorted(glob.glob(os.path.join(d, "*.pdb")), key=_numsort_key): + step_pdbs[step].append(pdb) + + entries = [] + for step in sorted(step_pdbs): + entries.append({ + "step": step, + "pdb_paths": sorted(step_pdbs[step], key=_numsort_key), + "label": f"step_{step:04d}", + }) + + # Find final output PDBs (in job_*_id_* subdirectories, excluding trajectory dirs) + final_pdbs = [] + for dirpath, dirnames, filenames in os.walk(input_dir): + if "_trajectory_" in dirpath or "trajectory_step_" in dirpath: + continue + if "eval_tmp" in dirpath: + continue + for f in filenames: + if f.endswith(".pdb"): + final_pdbs.append(os.path.join(dirpath, f)) + + if final_pdbs: + max_traj_step = max(step_pdbs) if step_pdbs else 0 + final_step = max_traj_step if max_traj_step > 0 else 9999 + entries.append({"step": final_step, "pdb_paths": sorted(final_pdbs, key=_numsort_key), "label": "final"}) + + return entries + + +def evaluate_step( + step_info: dict, + preloaded_esmfold: tuple, + tmp_base: str, +) -> list[dict]: + """Evaluate all PDBs for a single trajectory step. + + Returns a list of per-sample result dicts. + """ + step = step_info["step"] + label = step_info["label"] + results = [] + + for pdb_path in step_info["pdb_paths"]: + name = pdb_name_from_path(pdb_path) + seq = extract_seq_from_pdb(pdb_path) + n = len(seq) + + # Create tmp dir for this sample's eval artifacts + tmp_dir = os.path.join(tmp_base, f"{label}_{name}") + os.makedirs(tmp_dir, exist_ok=True) + + row = { + "step": step, + "label": label, + "pdb_path": pdb_path, + "length": n, + "sequence": seq, + } + + try: + res = scRMSD( + pdb_file_path=pdb_path, + ret_min=False, + tmp_path=tmp_dir, + use_pdb_seq=False, + rmsd_modes=["ca"], + folding_models=["esmfold"], + preloaded_esmfold=preloaded_esmfold, + keep_outputs=True, + ) + # res["ca"]["esmfold"] is a list of RMSD values (one per ProteinMPNN seq) + rmsds = res["ca"]["esmfold"] + row["scRMSD_ca_min"] = min(rmsds) if rmsds else float("inf") + row["scRMSD_ca_mean"] = sum(rmsds) / len(rmsds) if rmsds else float("inf") + row["scRMSD_ca_all"] = rmsds + except Exception as e: + logger.error(f"Failed to evaluate {pdb_path}: {e}") + row["scRMSD_ca_min"] = float("inf") + row["scRMSD_ca_mean"] = float("inf") + row["scRMSD_ca_all"] = [] + + results.append(row) + return results + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate trajectory snapshots") + parser.add_argument( + "--input_dir", + type=str, + required=True, + help="Generation output dir containing trajectory_step_*/ subdirs and final PDBs", + ) + parser.add_argument( + "--output_csv", + type=str, + default=None, + help="Path for output CSV (default: /trajectory_eval.csv)", + ) + args = parser.parse_args() + + torch.set_float32_matmul_precision("high") + + input_dir = args.input_dir + if not os.path.isdir(input_dir): + raise ValueError(f"Input directory does not exist: {input_dir}") + + output_csv = args.output_csv or os.path.join(input_dir, "trajectory_eval.csv") + + # Discover steps + steps = discover_steps(input_dir) + if not steps: + logger.error(f"No PDB files found in {input_dir}") + sys.exit(1) + + total_pdbs = sum(len(s["pdb_paths"]) for s in steps) + logger.info(f"Found {len(steps)} steps with {total_pdbs} total PDBs") + for s in steps: + logger.info(f" {s['label']}: {len(s['pdb_paths'])} PDBs") + + # Load ESMFold once + logger.info("Loading ESMFold model (one-time)...") + esmfold_model, esmfold_tokenizer = load_esmfold() + preloaded_esmfold = (esmfold_model, esmfold_tokenizer) + logger.info("ESMFold loaded.") + + # Load existing results to skip already-evaluated PDBs (resume support) + existing_pdbs = set() + prior_results = [] + if os.path.exists(output_csv): + prior_df = pd.read_csv(output_csv) + existing_pdbs = set(prior_df["pdb_path"].tolist()) + prior_results = prior_df.to_dict("records") + logger.info(f"Resuming: found {len(existing_pdbs)} already-evaluated PDBs in {output_csv}") + + # Filter out already-evaluated PDBs from each step + for step_info in steps: + before = len(step_info["pdb_paths"]) + step_info["pdb_paths"] = [p for p in step_info["pdb_paths"] if p not in existing_pdbs] + skipped = before - len(step_info["pdb_paths"]) + if skipped: + logger.info(f" {step_info['label']}: skipping {skipped}/{before} already evaluated") + + remaining = sum(len(s["pdb_paths"]) for s in steps) + if remaining == 0: + logger.info("All PDBs already evaluated, nothing to do.") + else: + logger.info(f"{remaining} PDBs remaining to evaluate") + + # Evaluate each step + tmp_base = os.path.join(input_dir, "eval_tmp") + all_results = list(prior_results) + + for step_info in steps: + if not step_info["pdb_paths"]: + continue + logger.info(f"Evaluating {step_info['label']} ({len(step_info['pdb_paths'])} PDBs)...") + step_results = evaluate_step(step_info, preloaded_esmfold, tmp_base) + all_results.extend(step_results) + + # Save incrementally after each step so progress survives crashes + df_inc = pd.DataFrame(all_results) + df_inc.drop(columns=["scRMSD_ca_all"], errors="ignore").to_csv(output_csv, index=False) + + # Final save + df = pd.DataFrame(all_results) + df_csv = df.drop(columns=["scRMSD_ca_all"], errors="ignore") + df_csv.to_csv(output_csv, index=False) + logger.info(f"Results saved to {output_csv}") + + # Print summary table + print("\n" + "=" * 70) + print("TRAJECTORY EVALUATION SUMMARY") + print("=" * 70) + summary = df.groupby(["step", "label"]).agg( + n_samples=("scRMSD_ca_min", "count"), + mean_scRMSD=("scRMSD_ca_min", "mean"), + median_scRMSD=("scRMSD_ca_min", "median"), + designability_rate=("scRMSD_ca_min", lambda x: (x < 2.0).mean()), + ).reset_index() + + print(f"\n{'Step':>6} {'Label':<12} {'N':>4} {'Mean scRMSD':>12} {'Med scRMSD':>11} {'Design. Rate':>13}") + print("-" * 70) + for _, row in summary.iterrows(): + print( + f"{row['step']:>6} {row['label']:<12} {row['n_samples']:>4} " + f"{row['mean_scRMSD']:>12.3f} {row['median_scRMSD']:>11.3f} " + f"{row['designability_rate']:>12.1%}" + ) + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/proteinfoundation/eval_trajectory_fast.py b/proteinfoundation/eval_trajectory_fast.py new file mode 100644 index 00000000..8df3b18d --- /dev/null +++ b/proteinfoundation/eval_trajectory_fast.py @@ -0,0 +1,493 @@ +"""Fast trajectory evaluation using coordinate-only metrics (no neural networks). + +Computes per-step: + - CA RMSD vs final structure (convergence) + - CA RMSD vs previous step (rate of change) + - Radius of gyration + - Secondary structure content (helix/sheet fraction via CA distances) + - (optional) ProteinMPNN backbone perplexity (cheap designability proxy) + +Runs in seconds, not hours. Use for trajectory analysis; reserve full scRMSD +eval (eval_trajectory.py) for final structures only. + +Usage: + python proteinfoundation/eval_trajectory_fast.py \ + --input_dir ./inference/inference_ucond_notri/2026-02-14_205759 + + # With backbone perplexity (adds ~0.5s per PDB on CPU): + python proteinfoundation/eval_trajectory_fast.py \ + --input_dir ./inference/inference_ucond_notri/2026-02-14_205759 \ + --perplexity +""" + +import argparse +import glob +import os +import re +import sys +from collections import defaultdict + +import numpy as np + + +def _numsort_key(path: str): + """Sort key that orders numeric segments numerically, not alphabetically.""" + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", path)] + +root = os.path.abspath(".") +sys.path.insert(0, root) +# isort: split + +import pandas as pd +from loguru import logger + +from proteinfoundation.utils.pdb_utils import load_pdb + +# CA is atom index 1 in atom37 representation. +CA_IDX = 1 + + +def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray: + """Numerically stable softmax.""" + e = np.exp(x - np.max(x, axis=axis, keepdims=True)) + return e / np.sum(e, axis=axis, keepdims=True) + + +def extract_ca_coords(pdb_path: str) -> np.ndarray: + """Load a PDB and return CA coordinates as (n, 3) array.""" + prot = load_pdb(pdb_path) + return prot.atom_positions[:, CA_IDX, :] + + +def kabsch_rmsd(coords1: np.ndarray, coords2: np.ndarray) -> float: + """Compute CA RMSD after optimal superposition (Kabsch alignment). + + Both inputs should be (n, 3) arrays with matching residue count. + """ + assert coords1.shape == coords2.shape + # Center + c1 = coords1 - coords1.mean(axis=0) + c2 = coords2 - coords2.mean(axis=0) + # Kabsch + H = c1.T @ c2 + U, S, Vt = np.linalg.svd(H) + d = np.linalg.det(Vt.T @ U.T) + sign_matrix = np.diag([1, 1, np.sign(d)]) + R = Vt.T @ sign_matrix @ U.T + c1_aligned = c1 @ R.T + return float(np.sqrt(np.mean(np.sum((c1_aligned - c2) ** 2, axis=1)))) + + +def radius_of_gyration(coords: np.ndarray) -> float: + """Radius of gyration from CA coordinates.""" + center = coords.mean(axis=0) + return float(np.sqrt(np.mean(np.sum((coords - center) ** 2, axis=1)))) + + +def secondary_structure_fractions(coords: np.ndarray) -> dict: + """Estimate helix/sheet content from CA-CA distance patterns. + + Uses the characteristic CA-CA distances: + - Alpha helix: CA(i) to CA(i+3) ~ 5.1-5.5 A + - Beta sheet: CA(i) to CA(i+2) ~ 6.5-7.5 A (extended) + """ + n = len(coords) + if n < 4: + return {"helix_frac": 0.0, "sheet_frac": 0.0} + + helix_count = 0 + sheet_count = 0 + for i in range(n - 3): + d_i3 = np.linalg.norm(coords[i] - coords[i + 3]) + if 4.8 <= d_i3 <= 5.8: + helix_count += 1 + for i in range(n - 2): + d_i2 = np.linalg.norm(coords[i] - coords[i + 2]) + if 6.2 <= d_i2 <= 7.8: + sheet_count += 1 + + return { + "helix_frac": helix_count / max(n - 3, 1), + "sheet_frac": sheet_count / max(n - 2, 1), + } + + +def discover_samples(input_dir: str) -> dict[int, dict]: + """Discover per-sample trajectory directories and final PDBs. + + Returns {sample_idx: {"trajectory": [(step, path), ...], "final": path_or_None}} + """ + samples: dict[int, dict] = defaultdict(lambda: {"trajectory": [], "final": None}) + + # New layout: per-sample dirs *_trajectory_N + for d in sorted(glob.glob(os.path.join(input_dir, "*_trajectory_*")), key=_numsort_key): + if not os.path.isdir(d): + continue + dir_match = re.search(r"_trajectory_(\d+)$", d) + if not dir_match: + continue + sample_idx = int(dir_match.group(1)) + for pdb in sorted(glob.glob(os.path.join(d, "*.pdb")), key=_numsort_key): + step_match = re.search(r"_step_(\d+)_", os.path.basename(pdb)) + if step_match: + samples[sample_idx]["trajectory"].append( + (int(step_match.group(1)), pdb) + ) + + # Legacy layout: per-step dirs trajectory_step_NNNN + for d in sorted(glob.glob(os.path.join(input_dir, "trajectory_step_*")), key=_numsort_key): + if not os.path.isdir(d): + continue + step_match = re.search(r"trajectory_step_(\d+)$", d) + if not step_match: + continue + step = int(step_match.group(1)) + for pdb in sorted(glob.glob(os.path.join(d, "*.pdb")), key=_numsort_key): + # Extract sample index from filename + sample_match = re.search(r"sample_(\d+)", os.path.basename(pdb)) + sample_idx = int(sample_match.group(1)) if sample_match else 0 + samples[sample_idx]["trajectory"].append((step, pdb)) + + # Final PDBs (in job_*_n_*_id_* subdirectories) + final_pdbs = [] + for dirpath, _, filenames in os.walk(input_dir): + if "_trajectory_" in dirpath or "trajectory_step_" in dirpath: + continue + if "eval_tmp" in dirpath: + continue + for f in filenames: + if f.endswith(".pdb"): + final_pdbs.append(os.path.join(dirpath, f)) + + # Match finals to trajectory samples by protein length. + # Extract length from path (_n__) or by loading the PDB. + def _get_length(path: str) -> int: + m = re.search(r"_n_(\d+)[_./]", path) + if m: + return int(m.group(1)) + # Fallback: load PDB and count residues + return len(extract_ca_coords(path)) + + finals_by_length: dict[int, list[str]] = defaultdict(list) + for pdb in sorted(final_pdbs, key=_numsort_key): + finals_by_length[_get_length(pdb)].append(pdb) + + assigned_by_length: dict[int, int] = defaultdict(int) + for sample_idx in sorted(samples): + traj = samples[sample_idx]["trajectory"] + if not traj: + continue + length = _get_length(traj[0][1]) + idx = assigned_by_length[length] + if idx < len(finals_by_length.get(length, [])): + samples[sample_idx]["final"] = finals_by_length[length][idx] + assigned_by_length[length] += 1 + + # If no trajectory samples were matched to finals, assign by order + if not any(s["final"] for s in samples.values()): + for i, pdb in enumerate(sorted(final_pdbs, key=_numsort_key)): + samples[i]["final"] = pdb + + # Sort trajectories by step + for s in samples.values(): + s["trajectory"].sort(key=lambda x: x[0]) + + return dict(samples) + + +def evaluate_sample(sample_idx: int, sample_info: dict, pmpnn_model=None) -> list[dict]: + """Evaluate one sample's trajectory + final structure. + + Args: + sample_idx: Index of the sample. + sample_info: Dict with "trajectory" and "final" keys. + pmpnn_model: Pre-loaded ProteinMPNN model for backbone perplexity. + If None, perplexity is skipped. + """ + results = [] + trajectory = sample_info["trajectory"] + final_path = sample_info["final"] + + # Load final CA coords for RMSD-to-final + final_ca = None + if final_path: + final_ca = extract_ca_coords(final_path) + + prev_ca = None + all_paths = [(step, path) for step, path in trajectory] + if final_path: + max_step = max(s for s, _ in trajectory) if trajectory else 0 + all_paths.append((max_step, final_path)) + + for step, pdb_path in all_paths: + ca = extract_ca_coords(pdb_path) + is_final = pdb_path == final_path + label = "final" if is_final else f"step_{step:04d}" + + row = { + "sample": sample_idx, + "step": step, + "label": label, + "pdb_path": pdb_path, + "length": len(ca), + "radius_of_gyration": radius_of_gyration(ca), + } + + # RMSD vs final + if is_final: + row["rmsd_to_final"] = 0.0 + elif final_ca is not None: + try: + row["rmsd_to_final"] = kabsch_rmsd(ca, final_ca) + except Exception: + row["rmsd_to_final"] = float("nan") + else: + row["rmsd_to_final"] = float("nan") + + # RMSD vs previous step + if prev_ca is not None and len(ca) == len(prev_ca): + row["rmsd_to_prev"] = kabsch_rmsd(ca, prev_ca) + else: + row["rmsd_to_prev"] = float("nan") + + # Secondary structure + ss = secondary_structure_fractions(ca) + row["helix_frac"] = ss["helix_frac"] + row["sheet_frac"] = ss["sheet_frac"] + + # AA logit statistics (if saved alongside PDB) + logits_path = pdb_path.replace(".pdb", "_aa_logits.npy") + if os.path.exists(logits_path): + logits = np.load(logits_path) # [n, 20] + probs = _softmax(logits) # [n, 20] + # Entropy: how uncertain the AA prediction is (max = log(20) ≈ 3.0) + entropy = -np.sum(probs * np.log(probs + 1e-10), axis=-1) # [n] + row["aa_entropy_mean"] = float(np.mean(entropy)) + # Max probability: how confident the top AA pick is + row["aa_max_prob_mean"] = float(np.mean(np.max(probs, axis=-1))) + else: + row["aa_entropy_mean"] = float("nan") + row["aa_max_prob_mean"] = float("nan") + + # Backbone perplexity via ProteinMPNN (if model provided) + if pmpnn_model is not None: + try: + from proteinfoundation.metrics.backbone_perplexity import ( + backbone_perplexity, + ) + bp = backbone_perplexity(pdb_path, model=pmpnn_model) + row["bb_perplexity"] = bp["perplexity"] + row["bb_mean_nll"] = bp["mean_nll"] + except Exception as e: + logger.warning(f"Perplexity failed for {pdb_path}: {e}") + row["bb_perplexity"] = float("nan") + row["bb_mean_nll"] = float("nan") + else: + row["bb_perplexity"] = float("nan") + row["bb_mean_nll"] = float("nan") + + results.append(row) + prev_ca = ca + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Fast trajectory evaluation (no GPU)") + parser.add_argument( + "--input_dir", + type=str, + required=True, + help="Generation output dir with trajectory and final PDBs", + ) + parser.add_argument( + "--output_csv", + type=str, + default=None, + help="Path for output CSV (default: /trajectory_fast_eval.csv)", + ) + parser.add_argument( + "--convergence_threshold", + type=float, + default=0.5, + help="RMSD threshold (Angstroms) for convergence detection (default: 0.5)", + ) + parser.add_argument( + "--perplexity", + action=argparse.BooleanOptionalAction, + default=None, + help="Compute ProteinMPNN backbone perplexity. Default: auto (on if GPU available, off otherwise). Use --no-perplexity to force off.", + ) + parser.add_argument( + "--perplexity_device", + type=str, + default=None, + help="Device for ProteinMPNN backbone perplexity (default: auto-detect)", + ) + args = parser.parse_args() + + input_dir = args.input_dir + if not os.path.isdir(input_dir): + raise ValueError(f"Input directory does not exist: {input_dir}") + + output_csv = args.output_csv or os.path.join(input_dir, "trajectory_fast_eval.csv") + + samples = discover_samples(input_dir) + if not samples: + logger.error(f"No PDB files found in {input_dir}") + sys.exit(1) + + total_pdbs = sum( + len(s["trajectory"]) + (1 if s["final"] else 0) for s in samples.values() + ) + logger.info(f"Found {len(samples)} samples with {total_pdbs} total PDBs") + + # Auto-detect: enable perplexity by default if a GPU is available + import torch + has_gpu = torch.cuda.is_available() + run_perplexity = args.perplexity if args.perplexity is not None else has_gpu + perplexity_device = args.perplexity_device or ("cuda" if has_gpu else "cpu") + + pmpnn_model = None + if run_perplexity: + from proteinfoundation.metrics.backbone_perplexity import load_proteinmpnn + logger.info(f"Loading ProteinMPNN for backbone perplexity (device={perplexity_device})") + pmpnn_model = load_proteinmpnn(device=perplexity_device) + logger.info("ProteinMPNN loaded.") + + all_results = [] + for sample_idx in sorted(samples): + info = samples[sample_idx] + logger.info( + f" Sample {sample_idx}: {len(info['trajectory'])} trajectory steps" + + (f" + final" if info["final"] else "") + ) + all_results.extend(evaluate_sample(sample_idx, info, pmpnn_model=pmpnn_model)) + + df = pd.DataFrame(all_results) + df.to_csv(output_csv, index=False) + logger.info(f"Results saved to {output_csv}") + + # Print summary table per sample + has_logits = "aa_entropy_mean" in df.columns and df["aa_entropy_mean"].notna().any() + has_perplexity = "bb_perplexity" in df.columns and df["bb_perplexity"].notna().any() + + print("\n" + "=" * 130) + print("FAST TRAJECTORY EVALUATION") + print("=" * 130) + + for sample_idx in sorted(samples): + sdf = df[df["sample"] == sample_idx] + n = sdf["length"].iloc[0] + print(f"\nSample {sample_idx} (n={n}):") + header = ( + f" {'Step':>6} {'Label':<12} {'RMSD->final':>12} " + f"{'RMSD->prev':>12} {'Rg':>8} {'Helix%':>7} {'Sheet%':>7}" + ) + if has_logits: + header += f" {'AA Entropy':>11} {'AA MaxProb':>11}" + if has_perplexity: + header += f" {'BB Perpl':>9}" + print(header) + print(" " + "-" * (len(header) - 2)) + for _, row in sdf.iterrows(): + rmsd_final = f"{row['rmsd_to_final']:>12.3f}" if not np.isnan(row["rmsd_to_final"]) else f"{'--':>12}" + rmsd_prev = f"{row['rmsd_to_prev']:>12.3f}" if not np.isnan(row["rmsd_to_prev"]) else f"{'--':>12}" + line = ( + f" {row['step']:>6} {row['label']:<12} {rmsd_final} " + f"{rmsd_prev} {row['radius_of_gyration']:>8.2f} " + f"{row['helix_frac']:>6.1%} {row['sheet_frac']:>6.1%}" + ) + if has_logits: + aa_ent = f"{row['aa_entropy_mean']:>11.3f}" if not np.isnan(row.get("aa_entropy_mean", float("nan"))) else f"{'--':>11}" + aa_max = f"{row['aa_max_prob_mean']:>11.3f}" if not np.isnan(row.get("aa_max_prob_mean", float("nan"))) else f"{'--':>11}" + line += f" {aa_ent} {aa_max}" + if has_perplexity: + bb_ppl = f"{row['bb_perplexity']:>9.2f}" if not np.isnan(row.get("bb_perplexity", float("nan"))) else f"{'--':>9}" + line += f" {bb_ppl}" + print(line) + + # Convergence analysis + traj_df = df[df["label"] != "final"].copy() + thresholds = [1.0, 0.5, 0.1] + primary_threshold = args.convergence_threshold + + def find_convergence_step(rmsd_series, threshold): + """Find first step where rmsd_to_prev < threshold and stays below for all subsequent steps.""" + values = rmsd_series.values + for i in range(len(values)): + if np.isnan(values[i]): + continue + if all(v < threshold for v in values[i:] if not np.isnan(v)): + return i + return None + + print("\n" + "-" * 110) + print("CONVERGENCE SUMMARY") + print("-" * 110) + + if not traj_df.empty: + max_step = int(traj_df["step"].max()) + + # Per-sample convergence + for sample_idx in sorted(samples): + sdf = traj_df[traj_df["sample"] == sample_idx].sort_values("step") + if sdf.empty: + continue + n = sdf["length"].iloc[0] + first_rmsd = sdf.iloc[0]["rmsd_to_final"] + last_rmsd = sdf.iloc[-1]["rmsd_to_final"] + r_first = f"{first_rmsd:>8.3f}" if not np.isnan(first_rmsd) else f"{'--':>8}" + r_last = f"{last_rmsd:>8.3f}" if not np.isnan(last_rmsd) else f"{'--':>8}" + + conv_parts = [] + for t in thresholds: + idx = find_convergence_step(sdf["rmsd_to_prev"], t) + if idx is not None: + conv_step = int(sdf.iloc[idx]["step"]) + pct = (max_step - conv_step) / max_step * 100 + conv_parts.append(f"<{t}A @ step {conv_step:>4}") + else: + conv_parts.append(f"<{t}A @ {'--':>7}") + + print( + f" Sample {sample_idx:>3} (n={n:>3}): " + f"RMSD->final {r_first} -> {r_last} " + f"{' | '.join(conv_parts)}" + ) + + # Aggregate adaptive stopping potential + print("\n" + "-" * 110) + print("ADAPTIVE STOPPING POTENTIAL") + print("-" * 110) + + for t in thresholds: + conv_steps = [] + for sample_idx in sorted(samples): + sdf = traj_df[traj_df["sample"] == sample_idx].sort_values("step") + if sdf.empty: + continue + idx = find_convergence_step(sdf["rmsd_to_prev"], t) + if idx is not None: + conv_steps.append(int(sdf.iloc[idx]["step"])) + + if conv_steps: + median_step = int(np.median(conv_steps)) + savings = (max_step - median_step) / max_step * 100 + n_converged = len(conv_steps) + n_total = len(samples) + marker = " <--" if t == primary_threshold else "" + print( + f" RMSD < {t:.1f} A: " + f"median convergence at step {median_step:>4}/{max_step} " + f"({savings:>4.0f}% savings) " + f"[{n_converged}/{n_total} samples converged]{marker}" + ) + else: + print(f" RMSD < {t:.1f} A: no samples converged") + + print("=" * 110 + "\n") + + +if __name__ == "__main__": + main() diff --git a/proteinfoundation/evaluate.py b/proteinfoundation/evaluate.py index b15ddffe..e8647a3b 100644 --- a/proteinfoundation/evaluate.py +++ b/proteinfoundation/evaluate.py @@ -7,6 +7,7 @@ # isort: split +import lightning as L import pandas as pd import torch from biotite.structure.io import load_structure @@ -47,24 +48,37 @@ def parse_cfg_for_table(cfg: Dict) -> Tuple[List[str], Dict]: return columns, flat_dict -def split_by_job(cfg: Dict, job_id: int, is_des: bool = True) -> List[str]: +def split_by_job(cfg: Dict, job_id: int, is_des: bool = True, search_dir: str = None, find_all: bool = False) -> List[str]: """ Split evaluation jobs by job id. For designability, select files starting with `job_{job_id}_`, as each eval job will start after the corresponding generation job finishes For FID, uniformly assign files to each job. We ususally only use 1 eval job for FID. + Args: + cfg: Config dict. + job_id: Job id for this evaluation split. + is_des: Whether this is for designability evaluation. + search_dir: Directory to search for PDB files. + find_all: If True, find all .pdb files regardless of naming convention + (useful when --input_dir is provided and files don't follow job_* pattern). + Returns: List of paths to where PDBs are stored (each PDB is at a different path). """ if is_des: sample_root_paths = [] - for root, dirs, files in os.walk(root_path): + for root, dirs, files in os.walk(search_dir): + # Skip trajectory subdirectories — those are evaluated by eval_trajectory.py + if "trajectory_step_" in root: + continue for file in files: - if file.startswith(f"job_{job_id}_") and file.endswith(".pdb"): + if not file.endswith(".pdb"): + continue + if find_all or file.startswith(f"job_{job_id}_"): sample_root_paths.append(os.path.join(root, file)) logger.info( - f"Job id {job_id} for designability or novelty evaluation for {len(sample_root_paths)} files starting with `job_{job_id}_`" + f"Job id {job_id} for designability or novelty evaluation for {len(sample_root_paths)} PDB files" ) else: raise NotImplementedError("New metrics not implemented.") @@ -398,12 +412,22 @@ def compute_traditional_metrics( args, cfg, config_name = parse_args_and_cfg() run_name = cfg.run_name_ ncpus = cfg.ncpus_ - root_path = setup( - cfg, create_root=False, config_name=config_name, job_id=args.job_id - ) + + if args.input_dir is not None: + # Use the user-provided directory instead of the config-derived path. + root_path = args.input_dir + if not os.path.exists(root_path): + raise ValueError(f"Input directory does not exist: {root_path}") + # Still need seeding from setup, but skip path construction. + cfg.seed = cfg.seed + args.job_id + L.seed_everything(cfg.seed) + else: + root_path = setup( + cfg, create_root=False, config_name=config_name, job_id=args.job_id + ) cfg_metric = cfg.generation.metric - + # Code for designability if cfg_metric.compute_designability: gen_njobs = cfg.get("gen_njobs", 1) @@ -411,7 +435,11 @@ def compute_traditional_metrics( assert ( gen_njobs == eval_njobs ), f"The numbers of generation and evaluation jobs for traditaional metrics should be equal." - samples_paths = split_by_job(cfg, args.job_id, is_des=True) + samples_paths = split_by_job( + cfg, args.job_id, is_des=True, + search_dir=root_path, + find_all=(args.input_dir is not None), + ) df = compute_traditional_metrics(cfg_metric, samples_paths, args.job_id, ncpus) if "motif_task_name" in cfg.generation.dataset: csv_filename = f"results_{config_name}_{cfg.generation.dataset.motif_task_name}_{args.job_id}.csv" diff --git a/proteinfoundation/flow_matching/product_space_flow_matcher.py b/proteinfoundation/flow_matching/product_space_flow_matcher.py index a99863a3..f82495d5 100644 --- a/proteinfoundation/flow_matching/product_space_flow_matcher.py +++ b/proteinfoundation/flow_matching/product_space_flow_matcher.py @@ -1,8 +1,10 @@ +import os from typing import Callable, Dict, Optional, Tuple, Union import lightning as L import torch from jaxtyping import Bool, Float +from loguru import logger from torch import Tensor from proteinfoundation.flow_matching.rdn_flow_matcher import RDNFlowMatcher @@ -502,6 +504,279 @@ def _add_clean_n_sim_tensor(batch, mode): ) # Adds guided tensor used for simulation ("guided_v", "guided_score", or whatever the base flow matcher uses) return nn_out + def adaptive_heun_simulation( + self, + batch: Dict, + predict_for_sampling: Callable, + max_nfe: int, + nsamples: int, + n: int, + self_cond: bool, + sampling_model_args: Dict[str, Dict], + device: torch.device, + save_trajectory_every: int = 0, + guidance_w: float = 1.0, + ag_ratio: float = 0.0, + atol: float = 1e-3, + rtol: float = 1e-2, + dt_init: float = 0.01, + dt_min: float = 1e-4, + dt_max: float = 0.1, + safety_factor: float = 0.9, + ) -> Tuple[Dict[str, Tensor], Dict]: + """ + Adaptive Heun-Euler embedded method for ODE integration from t=0 to t=1. + + Uses the Heun method (trapezoidal rule) for stepping and the difference + between the Euler trial and Heun-corrected steps as a free local + truncation error estimate. Step size is adjusted based on the error + ratio, with steps rejected when the error exceeds tolerance. + + All samples in a batch share the same adaptive schedule (max error + across samples drives acceptance/rejection). + + Args: + batch: Input batch dict. + predict_for_sampling: Callable that runs the neural network. + max_nfe: Maximum number of function evaluations (NFE budget). + nsamples: Number of samples to generate. + n: Protein length. + self_cond: Whether to use self-conditioning. + sampling_model_args: Model-specific sampling arguments. + device: Torch device. + save_trajectory_every: Save trajectory at NFE milestones (0 to disable). + guidance_w: Guidance weight. + ag_ratio: Autoguidance ratio. + atol: Absolute error tolerance. + rtol: Relative error tolerance. + dt_init: Initial step size. + dt_min: Minimum step size. + dt_max: Maximum step size. + safety_factor: Safety factor for step size adjustment. + + Returns: + x: Generated samples dict. + additional_info: Dict with mask, trajectory, solver info, NFE counts. + """ + # Squeeze batch-dim-1 tensors (same as full_simulation) + for key, value in batch.items(): + if ( + isinstance(value, torch.Tensor) + and value.dim() > 0 + and value.size(0) == 1 + ): + batch[key] = value.squeeze(0) + + # Use mask from batch if present, otherwise generate it + if "mask" in batch and batch["mask"] is not None: + mask = batch["mask"] + else: + mask = torch.ones(nsamples, n).long().bool().to(device) + assert mask.shape == (nsamples, n) + + trajectory = [] + + with torch.no_grad(): + use_bf16 = os.environ.get("LP_USE_BF16", "1") == "1" and device.type == "cuda" + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_bf16): + x = self.sample_noise( + n, + shape=(nsamples,), + device=device, + mask=mask, + ) + + t = 0.0 + dt = dt_init + nfe_used = 0 + steps_accepted = 0 + steps_rejected = 0 + x_1_pred = None # for self-conditioning + + while t < 1.0 and nfe_used + 2 <= max_nfe: + dt = min(dt, 1.0 - t) + + # --- First NN evaluation at (x, t) --- + t_tensor = { + dm: t * torch.ones(nsamples, device=device) + for dm in self.data_modes + } + batch["x_t"] = x + batch["t"] = t_tensor + batch["mask"] = mask + + if steps_accepted > 0 and self_cond and x_1_pred is not None: + batch["x_sc"] = x_1_pred + + nn_out_1 = self.get_clean_pred_n_guided_vector( + batch=batch, + predict_for_sampling=predict_for_sampling, + guidance_w=guidance_w, + ag_ratio=ag_ratio, + ) + nfe_used += 1 + + v1 = {} + for dm in self.data_modes: + v1[dm] = nn_out_1[dm].get("v_guided", nn_out_1[dm]["v"]) + + x_1_pred_1 = self.nn_out_to_clean_sample_prediction( + batch=batch, nn_out=nn_out_1 + ) + + # --- Trial Euler step --- + x_trial = {dm: x[dm] + v1[dm] * dt for dm in self.data_modes} + x_trial = self._apply_mask(x_trial, mask) + for dm in self.data_modes: + if sampling_model_args[dm]["simulation_step_params"].get("center_every_step", False): + x_trial[dm] = self.base_flow_matchers[dm]._force_zero_com(x_trial[dm], mask) + + # --- Second NN evaluation at (x_trial, t + dt) --- + t_trial_tensor = { + dm: (t + dt) * torch.ones(nsamples, device=device) + for dm in self.data_modes + } + batch["x_t"] = x_trial + batch["t"] = t_trial_tensor + if self_cond: + batch["x_sc"] = x_1_pred_1 + + nn_out_2 = self.get_clean_pred_n_guided_vector( + batch=batch, + predict_for_sampling=predict_for_sampling, + guidance_w=guidance_w, + ag_ratio=ag_ratio, + ) + nfe_used += 1 + + v2 = {} + for dm in self.data_modes: + v2[dm] = nn_out_2[dm].get("v_guided", nn_out_2[dm]["v"]) + + # --- Error estimate --- + # err = 0.5 * (v2 - v1) * dt (local truncation error of Euler) + err_ratio = 0.0 + for dm in self.data_modes: + err_dm = 0.5 * (v2[dm] - v1[dm]) * dt + x_scale = atol + rtol * torch.max( + torch.abs(x[dm]), torch.abs(x_trial[dm]) + ) + scaled = err_dm / x_scale # [b, n, d] + nres = mask.sum(dim=-1) # [b] + d = scaled.shape[-1] + per_sample = torch.sqrt( + (scaled**2 * mask[..., None]).sum(dim=(-1, -2)) / (nres * d) + ) # [b] + err_ratio = max(err_ratio, per_sample.max().item()) + + if err_ratio <= 1.0: + # Accept step: Heun update + x = { + dm: x[dm] + 0.5 * (v1[dm] + v2[dm]) * dt + for dm in self.data_modes + } + x = self._apply_mask(x, mask) + for dm in self.data_modes: + if sampling_model_args[dm]["simulation_step_params"].get("center_every_step", False): + x[dm] = self.base_flow_matchers[dm]._force_zero_com(x[dm], mask) + + t += dt + steps_accepted += 1 + + # Update x_1_pred from second (better) eval for self-conditioning + x_1_pred = self.nn_out_to_clean_sample_prediction( + batch=batch, nn_out=nn_out_2 + ) + + # Save trajectory at NFE milestones + if save_trajectory_every > 0 and nfe_used % save_trajectory_every == 0: + trajectory.append({ + "step": nfe_used, + "x": {k: v.detach().clone() for k, v in x.items()}, + }) + else: + steps_rejected += 1 + # On rejection: x_1_pred stays from last accepted step + + # --- Adjust dt --- + if err_ratio > 0: + dt *= safety_factor * (1.0 / err_ratio) ** 0.5 + else: + dt *= 2.0 # error is zero; double the step + dt = max(dt_min, min(dt, dt_max)) + + # If t < 1.0 and budget exhausted, do one final forced Heun step + if t < 1.0 and nfe_used + 2 <= max_nfe + 2: + dt_final = 1.0 - t + t_tensor = { + dm: t * torch.ones(nsamples, device=device) + for dm in self.data_modes + } + batch["x_t"] = x + batch["t"] = t_tensor + batch["mask"] = mask + if self_cond and x_1_pred is not None: + batch["x_sc"] = x_1_pred + + nn_out_1 = self.get_clean_pred_n_guided_vector( + batch=batch, + predict_for_sampling=predict_for_sampling, + guidance_w=guidance_w, + ag_ratio=ag_ratio, + ) + nfe_used += 1 + v1 = {dm: nn_out_1[dm].get("v_guided", nn_out_1[dm]["v"]) for dm in self.data_modes} + x_1_pred_1 = self.nn_out_to_clean_sample_prediction(batch=batch, nn_out=nn_out_1) + + x_trial = {dm: x[dm] + v1[dm] * dt_final for dm in self.data_modes} + x_trial = self._apply_mask(x_trial, mask) + for dm in self.data_modes: + if sampling_model_args[dm]["simulation_step_params"].get("center_every_step", False): + x_trial[dm] = self.base_flow_matchers[dm]._force_zero_com(x_trial[dm], mask) + + t_trial_tensor = { + dm: 1.0 * torch.ones(nsamples, device=device) + for dm in self.data_modes + } + batch["x_t"] = x_trial + batch["t"] = t_trial_tensor + if self_cond: + batch["x_sc"] = x_1_pred_1 + + nn_out_2 = self.get_clean_pred_n_guided_vector( + batch=batch, + predict_for_sampling=predict_for_sampling, + guidance_w=guidance_w, + ag_ratio=ag_ratio, + ) + nfe_used += 1 + v2 = {dm: nn_out_2[dm].get("v_guided", nn_out_2[dm]["v"]) for dm in self.data_modes} + + x = {dm: x[dm] + 0.5 * (v1[dm] + v2[dm]) * dt_final for dm in self.data_modes} + x = self._apply_mask(x, mask) + for dm in self.data_modes: + if sampling_model_args[dm]["simulation_step_params"].get("center_every_step", False): + x[dm] = self.base_flow_matchers[dm]._force_zero_com(x[dm], mask) + + t += dt_final + steps_accepted += 1 + x_1_pred = self.nn_out_to_clean_sample_prediction(batch=batch, nn_out=nn_out_2) + + logger.info( + f"Adaptive Heun: t_final={t:.6f}, NFEs={nfe_used}/{max_nfe}, " + f"steps_accepted={steps_accepted}, steps_rejected={steps_rejected}" + ) + + additional_info = { + "mask": mask, + "trajectory": trajectory, + "solver": "adaptive_heun", + "nfe_total": nfe_used, + "nsteps_taken": steps_accepted, + "nsteps_rejected": steps_rejected, + } + return x, additional_info + def full_simulation( self, batch: Dict, @@ -515,6 +790,7 @@ def full_simulation( save_trajectory_every: int = 0, guidance_w: float = 1.0, ag_ratio: float = 0.0, + solver: str = "euler", ) -> Dict[str, Tensor]: """ Generates samples by simulating the @@ -572,9 +848,32 @@ def full_simulation( mask = torch.ones(nsamples, n).long().bool().to(device) assert mask.shape == (nsamples, n) - if save_trajectory_every > 0: - # Create a list of trajectory dictionaries, one for each sample in batch - [{} for _ in range(nsamples)] + assert solver in ("euler", "heun", "adaptive_heun", "stochastic_heun"), f"Unknown solver: {solver}" + if solver in ("heun", "adaptive_heun"): + for dm in self.data_modes: + mode = sampling_model_args[dm]["simulation_step_params"]["sampling_mode"] + assert mode == "vf", ( + f"{solver} requires sampling_mode='vf' (ODE), but {dm} has '{mode}'" + ) + + if solver == "adaptive_heun": + adaptive_cfg = sampling_model_args.get("adaptive", {}) + return self.adaptive_heun_simulation( + batch=batch, + predict_for_sampling=predict_for_sampling, + max_nfe=nsteps * 2, + nsamples=nsamples, + n=n, + self_cond=self_cond, + sampling_model_args=sampling_model_args, + device=device, + save_trajectory_every=save_trajectory_every, + guidance_w=guidance_w, + ag_ratio=ag_ratio, + **adaptive_cfg, + ) + + trajectory = [] ts = { data_mode: get_schedule( @@ -595,69 +894,205 @@ def full_simulation( for data_mode in self.data_modes } - with torch.no_grad(): - x = self.sample_noise( - n, - shape=(nsamples,), - device=device, - mask=mask, - ) - - for step in range(nsteps): - t = { - data_mode: ts[data_mode][step] * torch.ones(nsamples, device=device) - for data_mode in self.data_modes - } - dt = { - data_mode: ts[data_mode][step + 1] - ts[data_mode][step] - for data_mode in self.data_modes - } - gt_step = { - data_mode: gt[data_mode][step] for data_mode in self.data_modes - } - - # Update the batch with current x, t, mask, etc. - batch["x_t"] = x - batch["t"] = t - batch["mask"] = mask - - # self conditioning - if step > 0 and self_cond: - batch["x_sc"] = x_1_pred - - # get clean prediction and guided vector - nn_out = self.get_clean_pred_n_guided_vector( - batch=batch, - predict_for_sampling=predict_for_sampling, - guidance_w=guidance_w, - ag_ratio=ag_ratio, + # Stochastic Heun needs gt at the trial points (t+dt) for drift computation + if solver == "stochastic_heun": + gt_trial = { + data_mode: get_gt( + t=ts[data_mode][1:], # [nsteps], gt at trial endpoints + mode=sampling_model_args[data_mode]["gt"]["mode"], + param=sampling_model_args[data_mode]["gt"]["p"], + clamp_val=sampling_model_args[data_mode]["gt"]["clamp_val"], ) - # Dict[data_mode, - # Dict[str, torch.Tensor] - # ] where str is some prediction ("x_1", "v", "score", "guided_v", "guided_score", ...) - # We just track all predictions this way + for data_mode in self.data_modes + } - x_1_pred = self.nn_out_to_clean_sample_prediction( - batch=batch, nn_out=nn_out - ) - # Dict[data_mode, torch.Tensor] - - simulation_step_params = { - data_mode: sampling_model_args[data_mode]["simulation_step_params"] - for data_mode in self.data_modes - } - x = self.simulation_step( - x_t=x, - nn_out=nn_out, - t=t, - dt=dt, - gt=gt_step, + with torch.no_grad(): + use_bf16 = os.environ.get("LP_USE_BF16", "1") == "1" and device.type == "cuda" + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_bf16): + x = self.sample_noise( + n, + shape=(nsamples,), + device=device, mask=mask, - simulation_step_params=simulation_step_params, ) + for step in range(nsteps): + t = { + data_mode: ts[data_mode][step] * torch.ones(nsamples, device=device) + for data_mode in self.data_modes + } + dt = { + data_mode: ts[data_mode][step + 1] - ts[data_mode][step] + for data_mode in self.data_modes + } + gt_step = { + data_mode: gt[data_mode][step] for data_mode in self.data_modes + } + + # Update the batch with current x, t, mask, etc. + batch["x_t"] = x + batch["t"] = t + batch["mask"] = mask + + # self conditioning + if step > 0 and self_cond: + batch["x_sc"] = x_1_pred + + # get clean prediction and guided vector + nn_out = self.get_clean_pred_n_guided_vector( + batch=batch, + predict_for_sampling=predict_for_sampling, + guidance_w=guidance_w, + ag_ratio=ag_ratio, + ) + # Dict[data_mode, + # Dict[str, torch.Tensor] + # ] where str is some prediction ("x_1", "v", "score", "guided_v", "guided_score", ...) + # We just track all predictions this way + + x_1_pred = self.nn_out_to_clean_sample_prediction( + batch=batch, nn_out=nn_out + ) + # Dict[data_mode, torch.Tensor] + + if solver == "heun": + # Extract v1 from first evaluation + v1 = {} + for dm in self.data_modes: + v1[dm] = nn_out[dm].get("v_guided", nn_out[dm]["v"]) + + # Trial Euler step + x_trial = {dm: x[dm] + v1[dm] * dt[dm] for dm in self.data_modes} + x_trial = self._apply_mask(x_trial, mask) + for dm in self.data_modes: + if sampling_model_args[dm]["simulation_step_params"].get("center_every_step", False): + x_trial[dm] = self.base_flow_matchers[dm]._force_zero_com(x_trial[dm], mask) + + # Second NN eval at trial point (x_trial, t+dt) + t_trial = { + dm: ts[dm][step + 1] * torch.ones(nsamples, device=device) + for dm in self.data_modes + } + batch["x_t"] = x_trial + batch["t"] = t_trial + if self_cond: + batch["x_sc"] = x_1_pred # use first eval's prediction + + nn_out_2 = self.get_clean_pred_n_guided_vector( + batch=batch, + predict_for_sampling=predict_for_sampling, + guidance_w=guidance_w, + ag_ratio=ag_ratio, + ) + + # Extract v2 and apply Heun correction + v2 = {} + for dm in self.data_modes: + v2[dm] = nn_out_2[dm].get("v_guided", nn_out_2[dm]["v"]) + + x = {dm: x[dm] + 0.5 * (v1[dm] + v2[dm]) * dt[dm] for dm in self.data_modes} + x = self._apply_mask(x, mask) + for dm in self.data_modes: + if sampling_model_args[dm]["simulation_step_params"].get("center_every_step", False): + x[dm] = self.base_flow_matchers[dm]._force_zero_com(x[dm], mask) + + # Update x_1_pred from second (better) eval for next step's self-conditioning + x_1_pred = self.nn_out_to_clean_sample_prediction( + batch=batch, nn_out=nn_out_2 + ) + + elif solver == "stochastic_heun": + # Stochastic Heun: trapezoidal correction on drift, + # same Brownian increment for trial and final steps. + v1 = {dm: nn_out[dm].get("v_guided", nn_out[dm]["v"]) for dm in self.data_modes} + + # Compute drift and diffusion at (x, t) + drift1 = {} + noise_term = {} + for dm in self.data_modes: + params = sampling_model_args[dm]["simulation_step_params"] + d1, dc1 = self.base_flow_matchers[dm].compute_drift_and_diffusion_coeff( + x[dm], v1[dm], t[dm], gt_step[dm], params + ) + drift1[dm] = d1 + eps_dm = torch.randn_like(x[dm]) + noise_term[dm] = dc1 * torch.sqrt(dt[dm]) * eps_dm + + # Trial Euler-Maruyama step (drift + same noise) + x_trial = {dm: x[dm] + drift1[dm] * dt[dm] + noise_term[dm] for dm in self.data_modes} + x_trial = self._apply_mask(x_trial, mask) + for dm in self.data_modes: + if sampling_model_args[dm]["simulation_step_params"].get("center_every_step", False): + x_trial[dm] = self.base_flow_matchers[dm]._force_zero_com(x_trial[dm], mask) + + # Second NN eval at trial point (x_trial, t+dt) + t_trial = { + dm: ts[dm][step + 1] * torch.ones(nsamples, device=device) + for dm in self.data_modes + } + batch["x_t"] = x_trial + batch["t"] = t_trial + if self_cond: + batch["x_sc"] = x_1_pred + + nn_out_2 = self.get_clean_pred_n_guided_vector( + batch=batch, + predict_for_sampling=predict_for_sampling, + guidance_w=guidance_w, + ag_ratio=ag_ratio, + ) + + # Compute drift at trial point + v2 = {dm: nn_out_2[dm].get("v_guided", nn_out_2[dm]["v"]) for dm in self.data_modes} + drift2 = {} + for dm in self.data_modes: + params = sampling_model_args[dm]["simulation_step_params"] + d2, _ = self.base_flow_matchers[dm].compute_drift_and_diffusion_coeff( + x_trial[dm], v2[dm], t_trial[dm], gt_trial[dm][step], params + ) + drift2[dm] = d2 + + # Stochastic Heun update: trapezoidal on drift, same noise realization + x = { + dm: x[dm] + 0.5 * (drift1[dm] + drift2[dm]) * dt[dm] + noise_term[dm] + for dm in self.data_modes + } + x = self._apply_mask(x, mask) + for dm in self.data_modes: + if sampling_model_args[dm]["simulation_step_params"].get("center_every_step", False): + x[dm] = self.base_flow_matchers[dm]._force_zero_com(x[dm], mask) + + x_1_pred = self.nn_out_to_clean_sample_prediction( + batch=batch, nn_out=nn_out_2 + ) + + else: + # Existing Euler path + simulation_step_params = { + data_mode: sampling_model_args[data_mode]["simulation_step_params"] + for data_mode in self.data_modes + } + x = self.simulation_step( + x_t=x, + nn_out=nn_out, + t=t, + dt=dt, + gt=gt_step, + mask=mask, + simulation_step_params=simulation_step_params, + ) + + if save_trajectory_every > 0 and (step + 1) % save_trajectory_every == 0: + trajectory.append({ + "step": step + 1, + "x": {k: v.detach().clone() for k, v in x.items()}, + }) + additional_info = { "mask": mask, + "trajectory": trajectory, + "solver": solver, + "nfe_total": nsteps * (2 if solver in ("heun", "stochastic_heun") else 1), } return x, additional_info diff --git a/proteinfoundation/flow_matching/rdn_flow_matcher.py b/proteinfoundation/flow_matching/rdn_flow_matcher.py index 121a4fcc..2dd44747 100644 --- a/proteinfoundation/flow_matching/rdn_flow_matcher.py +++ b/proteinfoundation/flow_matching/rdn_flow_matcher.py @@ -411,6 +411,93 @@ def simulation_step( return x_next + def compute_drift_and_diffusion_coeff( + self, + x_t: Float[Tensor, "* n d"], + v: Float[Tensor, "* n d"], + t: Float[Tensor, "*"], + gt: float, + simulation_step_params: Dict, + ) -> Tuple[Float[Tensor, "* n d"], float]: + """ + Decomposes the SDE step into drift and diffusion coefficient. + + For the SDE: dx = drift(x,t) dt + diffusion_coeff(t) dW + + This mirrors the logic of simulation_step but returns the drift and + diffusion coefficient separately (needed for higher-order SDE solvers + like stochastic Heun). + + Args: + x_t: Current value, shape [*, n, d] + v: Velocity field (already guided if guidance is active), shape [*, n, d] + t: Current time, shape [*] + gt: Noise injection schedule value, float or 0-dim tensor + simulation_step_params: Parameters dict with sampling_mode, scales, etc. + + Returns: + drift: Tensor same shape as x_t (what multiplies dt) + diffusion_coeff: float or 0-dim tensor (sigma; noise std is sigma * sqrt(dt) * eps) + """ + sampling_mode = simulation_step_params["sampling_mode"] + sc_scale_noise = simulation_step_params["sc_scale_noise"] + sc_scale_score = simulation_step_params["sc_scale_score"] + t_lim_ode = simulation_step_params["t_lim_ode"] + t_lim_ode_below = simulation_step_params["t_lim_ode_below"] + t_element = t.flatten()[0] + + sc_scale_score_def = 1.5 + sc_scale_noise_def = 0.3 + + # Clamp t for vf_to_score/score_to_vf safety (they assert strict bounds) + t_safe = torch.clamp(t, min=1e-5, max=1.0 - 1e-5) + + if sampling_mode == "vf": + return v, 0.0 + + elif sampling_mode == "sc": + if t_element > t_lim_ode: + score = vf_to_score(x_t, v, t_safe) + v_scaled = score_to_vf(x_t, score * sc_scale_score_def, t_safe) + return v_scaled, 0.0 + else: + score = vf_to_score(x_t, v, t_safe) + drift = v + gt * score + diffusion_coeff = (2.0 * gt * sc_scale_noise) ** 0.5 + return drift, diffusion_coeff + + elif sampling_mode == "vf_ss": + if t_element < t_lim_ode_below: + score = vf_to_score(x_t, v, t_safe) + drift = v + gt * score + diffusion_coeff = (2.0 * gt * sc_scale_noise_def) ** 0.5 + return drift, diffusion_coeff + else: + score = vf_to_score(x_t, v, t_safe) + v_scaled = score_to_vf(x_t, score * sc_scale_score, t_safe) + return v_scaled, 0.0 + + elif sampling_mode == "vf_ss_sc_sn": + if t_element > t_lim_ode: + score = vf_to_score(x_t, v, t_safe) + v_scaled = score_to_vf(x_t, score * sc_scale_score_def, t_safe) + return v_scaled, 0.0 + elif t_element < t_lim_ode_below: + score = vf_to_score(x_t, v, t_safe) + drift = v + gt * score + diffusion_coeff = (2.0 * gt * sc_scale_noise_def) ** 0.5 + return drift, diffusion_coeff + else: + score = vf_to_score(x_t, v, t_safe) + v_scaled = score_to_vf(x_t, score * sc_scale_score, t_safe) + drift = v_scaled + gt * score + diffusion_coeff = (2.0 * gt * sc_scale_noise) ** 0.5 + return drift, diffusion_coeff + + else: + raise ValueError(f"Invalid sampling mode {sampling_mode}") + + def vf_to_score( x_t: Float[Tensor, "* n d"], v: Float[Tensor, "* n d"], diff --git a/proteinfoundation/generate.py b/proteinfoundation/generate.py index 9c30920e..6513ea47 100644 --- a/proteinfoundation/generate.py +++ b/proteinfoundation/generate.py @@ -61,9 +61,62 @@ def parse_args_and_cfg() -> Tuple[Dict, Dict, str]: type=str, help="Name of the data path", ) + parser.add_argument( + "--cpu", + action="store_true", + help="Run on CPU only (slow, for development/testing without GPU)", + ) + parser.add_argument( + "--input_dir", + type=str, + default=None, + help="Directory containing generated PDBs to evaluate (overrides config-derived path)", + ) + parser.add_argument( + "--nsteps", + type=int, + default=None, + help="Override number of generation steps (default: from config, usually 400)", + ) + parser.add_argument( + "--save_trajectory_every", + type=int, + default=None, + help="Save intermediate PDB structures every N steps (0 or omit to disable)", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Override output directory (default: ./inference/)", + ) + parser.add_argument( + "--nsamples", + type=int, + default=None, + help="Override number of samples to generate per length (default: from config)", + ) + parser.add_argument( + "--lengths", + type=int, + nargs="+", + default=None, + help="Override protein lengths to sample (e.g. --lengths 80 100 120)", + ) + parser.add_argument( + "--solver", + type=str, + choices=["euler", "heun", "adaptive_heun", "stochastic_heun"], + default=None, + help="Solver: euler, heun (ODE), adaptive_heun (ODE, adaptive), or stochastic_heun (SDE, 2nd order)", + ) args = parser.parse_args() if args.data_path is not None: os.environ["DATA_PATH"] = args.data_path + # Ensure DATA_PATH is set so OmegaConf interpolations don't fail during + # config serialization (e.g. Lightning's CSV logger). The actual path is + # only needed for training/evaluation, not generation. + os.environ.setdefault("DATA_PATH", ".") # Inference config # If config_subdir is None then use base inference config # Otherwise use config_subdir/some_config @@ -79,13 +132,27 @@ def parse_args_and_cfg() -> Tuple[Dict, Dict, str]: else: config_name = args.config_name cfg = hydra.compose(config_name=config_name) - logger.info(f"Inference config {cfg}") + + # CLI overrides for generation parameters + if args.nsteps is not None: + cfg.generation.args.nsteps = args.nsteps + if args.save_trajectory_every is not None: + cfg.generation.args.save_trajectory_every = args.save_trajectory_every + if args.nsamples is not None: + cfg.generation.dataset.nsamples = args.nsamples + if args.lengths is not None: + cfg.generation.dataset.nlens_cfg.nres_lens = args.lengths + if args.solver is not None: + cfg.generation.args.solver = args.solver + + logger.info(f"Inference config {cfg}") return args, cfg, config_name def setup( - cfg: Dict, create_root: bool = True, config_name: str = ".", job_id: int = 0 + cfg: Dict, create_root: bool = True, config_name: str = ".", job_id: int = 0, cpu_mode: bool = False, + output_dir: str = None, ) -> str: """ Checks if metrics being computed are compatible, sets the right seed, and creates the root directory @@ -96,9 +163,12 @@ def setup( """ logger.info(" ".join(sys.argv)) - assert ( - torch.cuda.is_available() - ), "CUDA not available" # Needed for ESMfold and designability + if cpu_mode: + logger.warning("Running in CPU mode - this will be slow!") + else: + assert ( + torch.cuda.is_available() + ), "CUDA not available. Use --cpu flag to run on CPU (slow)." logger.add( sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}", @@ -113,19 +183,25 @@ def setup( or not cfg.generation.metric.compute_fid ), "Designability/Novelty cannot be computed together with FID" - # Set root path for this inference run - if "motif_task_name" in cfg.generation.dataset: + # Set root path for this inference run, dated by default + from datetime import datetime + date_str = datetime.now().strftime("%Y-%m-%d_%H%M%S") + if output_dir is not None: + root_path = output_dir + elif "motif_task_name" in cfg.generation.dataset: root_path = ( - f"./inference/{config_name}_{cfg.generation.dataset.motif_task_name}" + f"./inference/{config_name}_{cfg.generation.dataset.motif_task_name}/{date_str}" ) else: - root_path = f"./inference/{config_name}" + root_path = f"./inference/{config_name}/{date_str}" if create_root: os.makedirs(root_path, exist_ok=True) else: if not os.path.exists(root_path): raise ValueError("Results path %s does not exist" % root_path) + logger.info(f"Output directory: {os.path.abspath(root_path)}") + # Set seed cfg.seed = cfg.seed + job_id # Different seeds for different splits ids logger.info(f"Seeding everything to seed {cfg.seed}") @@ -348,7 +424,8 @@ def main(): cfg.generation.args.get("fold_cond", False) njobs = cfg.get("gen_njobs", 1) root_path = setup( - cfg, create_root=True, config_name=config_name, job_id=args.job_id + cfg, create_root=True, config_name=config_name, job_id=args.job_id, cpu_mode=args.cpu, + output_dir=args.output_dir, ) # Exit if results from analysis already exist (assumes samples already there) @@ -366,6 +443,11 @@ def main(): # Load model model = load_ckpt_n_configure_inference(cfg) + # Optional: compile the transformer for faster inference + if os.environ.get("LP_USE_COMPILE", "0") == "1": + logger.info("Compiling model.nn with torch.compile (mode=reduce-overhead)") + model.nn = torch.compile(model.nn, mode="reduce-overhead") + # Create generation dataset cfg_gen = split_by_job(cfg_gen, args.job_id, njobs) @@ -410,8 +492,13 @@ def main(): dataloader = DataLoader(dataset, batch_size=1, shuffle=False) # Sample model - trainer = L.Trainer(accelerator="gpu", devices=1) - predictions = trainer.predict(model, dataloader) + accelerator = "cpu" if args.cpu else "gpu" + trainer = L.Trainer(accelerator=accelerator, devices=1) + raw_predictions = trainer.predict(model, dataloader) + + # predict_step returns (generation_list, trajectory_list) per batch. + predictions = [gen for gen, _traj in raw_predictions] + all_trajectories = [traj for _gen, traj in raw_predictions] chain_indexes = None @@ -436,6 +523,66 @@ def main(): cath_codes=dataset.cath_codes, ) + # Save trajectory snapshots, organised by sample (one directory per sample). + # Each batch (one per length) has local sample indices starting at 0, + # so we use a global offset to give each sample a unique index. + # Collect: per_sample[global_idx] = [(step, coors, aatype, logits_or_None), ...] + per_sample = defaultdict(list) + global_offset = 0 + for batch_traj in all_trajectories: + if not batch_traj: + continue + n_in_batch = len(batch_traj[0]["samples"]) + for snap in batch_traj: + step = snap["step"] + for j, sample_data in enumerate(snap["samples"]): + coors, aatype = sample_data[0], sample_data[1] + logits = sample_data[2] if len(sample_data) > 2 else None + per_sample[global_offset + j].append((step, coors, aatype, logits)) + global_offset += n_in_batch + + for idx, entries in sorted(per_sample.items()): + n = entries[0][1].shape[0] + sample_dir = os.path.join(root_path, f"job_{args.job_id}_n_{n:03d}_trajectory_{idx:03d}") + os.makedirs(sample_dir, exist_ok=True) + for step, coors, aatype, logits in entries: + fname = f"job_{args.job_id}_step_{step:04d}_n_{n:03d}_sample_{idx:03d}.pdb" + write_prot_to_pdb( + prot_pos=coors.float().detach().cpu().numpy(), + aatype=aatype.detach().cpu().numpy(), + file_path=os.path.join(sample_dir, fname), + overwrite=True, + no_indexing=True, + ) + # Save AA logits (softmax probabilities) as .npy alongside the PDB + if logits is not None: + import numpy as np + logits_path = os.path.join(sample_dir, fname.replace(".pdb", "_aa_logits.npy")) + np.save(logits_path, logits.float().detach().cpu().numpy()) + + # Write generation stats JSON sidecar + if hasattr(model, "_generation_stats") and model._generation_stats: + stats_path = os.path.join(root_path, "generation_stats.json") + with open(stats_path, "w") as f: + json.dump(model._generation_stats, f, indent=2) + logger.info(f"Generation stats written to {stats_path}") + + # Print summary with output path and eval command + abs_root = os.path.abspath(root_path) + n_pdbs = sum(1 for f in os.listdir(root_path) if f.endswith(".pdb")) + print("\n" + "=" * 70) + print("GENERATION COMPLETE") + print("=" * 70) + print(f" Output directory: {abs_root}") + print(f" PDB files generated: {n_pdbs}") + print(f"\n To evaluate this run:") + print(f" python proteinfoundation/eval_trajectory.py \\") + print(f" --input_dir {abs_root}") + print(f"\n Or on Modal:") + print(f" modal run infra/modal_app.py::eval_trajectory \\") + print(f" --input-dir {abs_root}") + print("=" * 70 + "\n") + if __name__ == "__main__": main() diff --git a/proteinfoundation/metrics/backbone_perplexity.py b/proteinfoundation/metrics/backbone_perplexity.py new file mode 100644 index 00000000..a8bd7919 --- /dev/null +++ b/proteinfoundation/metrics/backbone_perplexity.py @@ -0,0 +1,159 @@ +"""ProteinMPNN backbone perplexity: a fast, cheap proxy for designability. + +Backbone perplexity measures how well the backbone geometry constrains amino +acid identity. ProteinMPNN's `unconditional_probs` mode computes p(s_i | backbone) +for all 20 standard amino acids at each position in a single forward pass (no +sequence context, just backbone geometry). + +Low perplexity (2-5) = backbone strongly constrains AA choice = likely designable. +High perplexity (>10) = backbone is ambiguous or unphysical = likely not designable. + +Runs in ~0.1-1s per structure on CPU. Compare to full designability eval +(ProteinMPNN + ESMFold + scRMSD) which takes 5-30 min per structure on GPU. + +Usage: + from proteinfoundation.metrics.backbone_perplexity import ( + load_proteinmpnn, backbone_perplexity, + ) + model = load_proteinmpnn(device="cpu") + result = backbone_perplexity("path/to/protein.pdb", model=model) + print(f"Perplexity: {result['perplexity']:.2f}") +""" + +import math +import os +import sys + +import numpy as np +import torch + +# ProteinMPNN lives at /ProteinMPNN/ +_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +_PMPNN_DIR = os.path.join(_ROOT, "ProteinMPNN") + + +def _ensure_pmpnn_importable(): + if _PMPNN_DIR not in sys.path: + sys.path.insert(0, _PMPNN_DIR) + + +def load_proteinmpnn(model_name="v_48_020", device="cpu", ca_only=True): + """Load a ProteinMPNN model for backbone perplexity scoring. + + Args: + model_name: Weight file stem (default v_48_020, the 0.20A noise model). + device: torch device string. + ca_only: If True, use CA-only model (recommended for generated backbones + that may lack full backbone atoms). + + Returns: + ProteinMPNN model in eval mode on the specified device. + """ + _ensure_pmpnn_importable() + from protein_mpnn_utils import ProteinMPNN as PMPNNModel + + weight_dir = "ca_model_weights" if ca_only else "vanilla_model_weights" + ckpt_path = os.path.join(_PMPNN_DIR, weight_dir, f"{model_name}.pt") + if not os.path.exists(ckpt_path): + raise FileNotFoundError( + f"ProteinMPNN weights not found at {ckpt_path}. " + "Run script_utils/download_pmpnn_weights.sh first." + ) + + checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False) + + model = PMPNNModel( + ca_only=ca_only, + num_letters=21, + node_features=128, + edge_features=128, + hidden_dim=128, + num_encoder_layers=3, + num_decoder_layers=3, + augment_eps=0.0, # no noise at inference + k_neighbors=checkpoint["num_edges"], + ) + model.to(device) + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + return model + + +def featurize_pdb(pdb_path, device="cpu", ca_only=True): + """Parse a PDB and return tensors needed by ProteinMPNN.unconditional_probs. + + Returns: + (X, mask, residue_idx, chain_encoding_all) tensors. + """ + _ensure_pmpnn_importable() + from protein_mpnn_utils import parse_PDB, tied_featurize + + pdb_dict_list = parse_PDB(pdb_path, ca_only=ca_only) + + # All chains are "designed" (masked); no fixed chains. + all_chains = [ + item[-1:] for item in list(pdb_dict_list[0]) if item[:9] == "seq_chain" + ] + chain_dict = {pdb_dict_list[0]["name"]: (all_chains, [])} + + # tied_featurize returns a 20-element tuple; we need elements 0, 2, 5, 12. + result = tied_featurize( + batch=pdb_dict_list, + device=device, + chain_dict=chain_dict, + ca_only=ca_only, + ) + X = result[0] # [1, L, 3] for ca_only + mask = result[2] # [1, L] + chain_encoding_all = result[5] # [1, L] + residue_idx = result[12] # [1, L] + + return X, mask, residue_idx, chain_encoding_all + + +def backbone_perplexity(pdb_path, model=None, device="cpu", ca_only=True): + """Compute backbone perplexity for a PDB file. + + Args: + pdb_path: Path to PDB file. + model: Pre-loaded ProteinMPNN model (from load_proteinmpnn). + If None, loads a fresh model (slow; prefer pre-loading). + device: torch device string. + ca_only: Use CA-only features. + + Returns: + dict with: + mean_nll: Mean negative log-likelihood of the best AA per position. + perplexity: exp(mean_nll). + per_residue_nll: List of per-residue NLL values. + length: Number of residues. + """ + if model is None: + model = load_proteinmpnn(device=device, ca_only=ca_only) + + # Infer device from model parameters so inputs land on the same device + model_device = next(model.parameters()).device + X, mask, residue_idx, chain_encoding_all = featurize_pdb( + pdb_path, device=model_device, ca_only=ca_only + ) + + with torch.no_grad(): + log_probs = model.unconditional_probs(X, mask, residue_idx, chain_encoding_all) + + # log_probs: [1, L_padded, 21] (21 = 20 standard AAs + X) + L = int(mask.sum().item()) + + # NLL of the most likely AA at each position (exclude X at index 20) + log_probs_aa = log_probs[0, :L, :20] # [L, 20] + max_log_probs = log_probs_aa.max(dim=-1).values # [L] + per_residue_nll = (-max_log_probs).cpu().tolist() + + mean_nll = float(np.mean(per_residue_nll)) + perplexity = float(math.exp(mean_nll)) + + return { + "mean_nll": mean_nll, + "perplexity": perplexity, + "per_residue_nll": per_residue_nll, + "length": L, + } diff --git a/proteinfoundation/metrics/designability.py b/proteinfoundation/metrics/designability.py index beb7650a..6118fbf8 100644 --- a/proteinfoundation/metrics/designability.py +++ b/proteinfoundation/metrics/designability.py @@ -14,7 +14,7 @@ from transformers import logging as hf_logging from openfold.np import residue_constants -from proteinfoundation.metrics.folding_models import run_esmfold +from proteinfoundation.metrics.folding_models import load_esmfold, run_esmfold from proteinfoundation.utils.align_utils import kabsch_align_ind from proteinfoundation.utils.coors_utils import ( get_atom37_bb3_mask, @@ -311,6 +311,7 @@ def scRMSD( # cache_dir: Optional[str] = "/lustre/fsw/portfolios/nvr/users/kdidi/.cache", cache_dir: Optional[str] = None, keep_outputs: bool = False, + preloaded_esmfold: Optional[tuple] = None, ) -> Dict[str, Union[float, List[float]]]: """ Evaluates self-consistency RMSD metrics for given pdb. @@ -401,8 +402,13 @@ def scRMSD( os.makedirs(model_tmp_path, exist_ok=True) if model == "esmfold": + esm_model = None + esm_tokenizer = None + if preloaded_esmfold is not None: + esm_model, esm_tokenizer = preloaded_esmfold out_folding_paths = run_esmfold( - gen_seqs, model_tmp_path, name, suffix=suffix, cache_dir=cache_dir, keep_outputs=keep_outputs + gen_seqs, model_tmp_path, name, suffix=suffix, cache_dir=cache_dir, keep_outputs=keep_outputs, + esm_model=esm_model, tokenizer=esm_tokenizer, ) elif model == "colabfold": out_folding_paths = run_colabfold( diff --git a/proteinfoundation/metrics/folding_models.py b/proteinfoundation/metrics/folding_models.py index 659cd13f..db7d5043 100644 --- a/proteinfoundation/metrics/folding_models.py +++ b/proteinfoundation/metrics/folding_models.py @@ -61,6 +61,31 @@ def create_individual_fasta_files( return output_dir +def load_esmfold(cache_dir: Optional[str] = None): + """Load ESMFold model and tokenizer once, for reuse across multiple calls. + + Args: + cache_dir: Cache directory for model weights. + + Returns: + Tuple of (esm_model, tokenizer). + """ + is_cluster_run = os.environ.get("SLURM_JOB_ID") is not None + + final_cache_dir = cache_dir + if final_cache_dir is None and is_cluster_run: + final_cache_dir = os.environ.get("CACHE_DIR") + + tokenizer = AutoTokenizer.from_pretrained( + "facebook/esmfold_v1", cache_dir=final_cache_dir + ) + esm_model = EsmForProteinFolding.from_pretrained( + "facebook/esmfold_v1", cache_dir=final_cache_dir + ) + esm_model = esm_model.cuda() + return esm_model, tokenizer + + def run_esmfold( sequences: List[str], path_to_esmfold_out: str, @@ -68,6 +93,8 @@ def run_esmfold( suffix: str, cache_dir: Optional[str] = None, keep_outputs: bool = False, + esm_model: Optional[EsmForProteinFolding] = None, + tokenizer: Optional[AutoTokenizer] = None, ) -> List[str]: """ Runs ESMFold on sequences and stores results as PDB files. @@ -82,24 +109,14 @@ def run_esmfold( suffix: to use as suffix when storing files cache_dir: Cache directory for model weights keep_outputs: Whether to keep output directories + esm_model: Pre-loaded ESMFold model (if None, loads fresh) + tokenizer: Pre-loaded tokenizer (if None, loads fresh) Returns: List of paths (list of str) to PDB files """ - is_cluster_run = os.environ.get("SLURM_JOB_ID") is not None - - # Use provided cache_dir or fallback to environment/cluster logic - final_cache_dir = cache_dir - if final_cache_dir is None and is_cluster_run: - final_cache_dir = os.environ.get("CACHE_DIR") - - tokenizer = AutoTokenizer.from_pretrained( - "facebook/esmfold_v1", cache_dir=final_cache_dir - ) - esm_model = EsmForProteinFolding.from_pretrained( - "facebook/esmfold_v1", cache_dir=final_cache_dir - ) - esm_model = esm_model.cuda() + if esm_model is None or tokenizer is None: + esm_model, tokenizer = load_esmfold(cache_dir) # Run ESMFold list_of_strings_pdb = [] diff --git a/proteinfoundation/nn/modules/pair_bias_attn.py b/proteinfoundation/nn/modules/pair_bias_attn.py index eb3da0df..2fc12e5f 100644 --- a/proteinfoundation/nn/modules/pair_bias_attn.py +++ b/proteinfoundation/nn/modules/pair_bias_attn.py @@ -20,11 +20,13 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import os from typing import Optional import torch from einops import rearrange from torch import Tensor, einsum, nn +from torch.nn import functional as F from proteinfoundation.nn.modules.adaptive_ln_scale import ( AdaptiveLayerNorm, @@ -114,7 +116,25 @@ def forward( return self.to_out_node(attn_feats) def _attn(self, q, k, v, b, mask: Optional[Tensor]) -> Tensor: - """Perform attention update""" + """Perform attention update (dispatches to SDPA or manual implementation)""" + if os.environ.get("LP_USE_SDPA", "1") == "1": + return self._attn_sdpa(q, k, v, b, mask) + return self._attn_manual(q, k, v, b, mask) + + def _attn_sdpa(self, q, k, v, b, mask: Optional[Tensor]) -> Tensor: + """Attention via F.scaled_dot_product_attention (fused kernels, lower memory)""" + # Combine pair bias and padding mask into a single additive attn_mask + attn_bias = b if not isinstance(b, int) else None + if exists(mask): + mask_bias = torch.zeros( + mask.shape[0], 1, mask.shape[1], mask.shape[2], + dtype=q.dtype, device=q.device, + ).masked_fill(~rearrange(mask, "b i j -> b () i j"), float("-inf")) + attn_bias = mask_bias + attn_bias if attn_bias is not None else mask_bias + return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) + + def _attn_manual(self, q, k, v, b, mask: Optional[Tensor]) -> Tensor: + """Original manual attention implementation (fallback)""" sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale if exists(mask): mask = rearrange(mask, "b i j -> b () i j") diff --git a/proteinfoundation/partial_autoencoder/autoencoder.py b/proteinfoundation/partial_autoencoder/autoencoder.py index b3e804f8..c573aeb3 100644 --- a/proteinfoundation/partial_autoencoder/autoencoder.py +++ b/proteinfoundation/partial_autoencoder/autoencoder.py @@ -143,12 +143,15 @@ def decode( coors_nm = ( output_dec["coors_nm"] * mask[..., None, None] * atom_mask[..., None] ) # [b, n, 37, 3] - return { + result = { "coors_nm": coors_nm, "residue_type": output_dec["aatype_max"] * mask, "residue_mask": mask, "atom_mask": atom_mask, } + if "seq_logits" in output_dec: + result["seq_logits"] = output_dec["seq_logits"] # [b, n, 20] + return result def training_step(self, batch: Dict, batch_idx: int): """ diff --git a/proteinfoundation/partial_autoencoder/inference.py b/proteinfoundation/partial_autoencoder/inference.py index 34ed9733..8c875dc0 100644 --- a/proteinfoundation/partial_autoencoder/inference.py +++ b/proteinfoundation/partial_autoencoder/inference.py @@ -69,6 +69,11 @@ def parse_args_and_cfg() -> Tuple[Dict, Dict, str]: help="(Optional) Name of directory with config files, if not included uses base inference config.\ Likely only used when submitting to the cluster with script.", ) + parser.add_argument( + "--cpu", + action="store_true", + help="Run on CPU only (slow, for development/testing without GPU)", + ) args = parser.parse_args() # Inference config @@ -101,6 +106,7 @@ def setup( cfg: Dict, config_name: str, create_root: bool = True, + cpu_mode: bool = False, ) -> str: """ Checks if metrics being computed are compatible, sets the right seed, and creates the root directory @@ -111,9 +117,12 @@ def setup( """ logger.info(" ".join(sys.argv)) - assert ( - torch.cuda.is_available() - ), "CUDA not available" # Needed for ESMfold and designability + if cpu_mode: + logger.warning("Running in CPU mode - this will be slow!") + else: + assert ( + torch.cuda.is_available() + ), "CUDA not available. Use --cpu flag to run on CPU (slow)." logger.add( sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {message}", @@ -243,7 +252,7 @@ def main() -> None: ae_name, ckpt_name = extract_ckpt_info(cfg.ckpt_file) # Some setup - root_path = setup(cfg, create_root=True, config_name=config_name) + root_path = setup(cfg, create_root=True, config_name=config_name, cpu_mode=args.cpu) df_file_store = os.path.join(root_path, f"../results_{config_name}.csv") df_file_store_summary = os.path.join( root_path, f"../results_{config_name}_summary.csv" @@ -256,8 +265,9 @@ def main() -> None: model = AutoEncoder.load_from_checkpoint(cfg.ckpt_file) # Make predictions, store them together with inputs + accelerator = "cpu" if args.cpu else "gpu" trainer = L.Trainer( - accelerator="gpu", devices=1, limit_predict_batches=int(cfg.n_structs / cfg.bs) + accelerator=accelerator, devices=1, limit_predict_batches=int(cfg.n_structs / cfg.bs) ) predictions = trainer.predict(model, dataloader) # List of tuples, each tuple is (data_batch, predicted_batch) diff --git a/proteinfoundation/proteina.py b/proteinfoundation/proteina.py index 548eaa5a..4abbe78e 100644 --- a/proteinfoundation/proteina.py +++ b/proteinfoundation/proteina.py @@ -513,12 +513,29 @@ def predict_step(self, batch: Dict, batch_idx: int) -> List[Tuple[torch.tensor]] List of tuples. Each tuple represents one of the generated samples, has two elements: coordinates tensor of shape [n, 37, 3], and aatype tensor of shape [n]. """ + import time + self_cond = self.inf_cfg.args.self_cond nsteps = self.inf_cfg.args.nsteps guidance_w = self.inf_cfg.args.get("guidance_w", 1.0) ag_ratio = self.inf_cfg.args.get("ag_ratio", 0.0) - save_trajectory_every = 0 + save_trajectory_every = self.inf_cfg.args.get("save_trajectory_every", 0) + + # Timing and memory tracking + is_cuda = self.device.type == "cuda" + if is_cuda: + torch.cuda.reset_peak_memory_stats(self.device) + torch.cuda.synchronize(self.device) + t_start = time.time() + + solver = self.inf_cfg.args.get("solver", "euler") + + # Build sampling_model_args; include adaptive config if present + sampling_model_args = dict(self.inf_cfg.model) + adaptive_cfg = self.inf_cfg.args.get("adaptive", None) + if adaptive_cfg is not None: + sampling_model_args["adaptive"] = dict(adaptive_cfg) fn_predict_for_sampling = partial( self.predict_for_sampling, n_recycle=self.inf_cfg.get("n_recycle", 0) @@ -530,12 +547,48 @@ def predict_step(self, batch: Dict, batch_idx: int) -> List[Tuple[torch.tensor]] nsamples=batch["nsamples"], n=batch["nres"], self_cond=self_cond, - sampling_model_args=self.inf_cfg.model, + sampling_model_args=sampling_model_args, device=self.device, save_trajectory_every=save_trajectory_every, guidance_w=guidance_w, ag_ratio=ag_ratio, + solver=solver, + ) + + # Record timing and peak memory + if is_cuda: + torch.cuda.synchronize(self.device) + wall_time_s = time.time() - t_start + peak_memory_gb = ( + torch.cuda.max_memory_allocated(self.device) / 1e9 if is_cuda else 0.0 + ) + logger.info( + f"predict_step batch {batch_idx}: length={batch['nres']}, nsamples={batch['nsamples']}, " + f"nsteps={nsteps}, wall_time={wall_time_s:.1f}s, peak_memory={peak_memory_gb:.2f}GB" ) + nfe_total = extra_info.get("nfe_total", nsteps * (2 if solver in ("heun", "stochastic_heun") else 1)) + if not hasattr(self, "_generation_stats"): + self._generation_stats = [] + stats_entry = { + "batch_idx": batch_idx, + "length": int(batch["nres"]), + "nsamples": int(batch["nsamples"]), + "nsteps": nsteps, + "solver": solver, + "nfe_total": nfe_total, + "wall_time_s": round(wall_time_s, 2), + "peak_memory_gb": round(peak_memory_gb, 3), + "device": str(self.device), + "use_sdpa": os.environ.get("LP_USE_SDPA", "1") == "1", + "use_bf16": os.environ.get("LP_USE_BF16", "1") == "1", + "use_compile": os.environ.get("LP_USE_COMPILE", "0") == "1", + } + if "nsteps_taken" in extra_info: + stats_entry["nsteps_taken"] = extra_info["nsteps_taken"] + if "nsteps_rejected" in extra_info: + stats_entry["nsteps_rejected"] = extra_info["nsteps_rejected"] + self._generation_stats.append(stats_entry) + # Dict with the data_modes as keys, and values with batch shape b # extra_info is a dict with additional things, including # "mask", whose value is boolean of shape [nsamples, n] @@ -554,7 +607,25 @@ def predict_step(self, batch: Dict, batch_idx: int) -> List[Tuple[torch.tensor]] generation_list.append( (sample_prots["coors"][i], sample_prots["residue_type"][i]) ) # Tuple (coors [n, 37, 3], aatype [n]) - return generation_list # List of tupes (coors [n, 37, 3], aatype [n]) + + # Decode trajectory snapshots through the autoencoder if present. + trajectory_list = [] + for snap in extra_info.get("trajectory", []): + traj_extra = {**extra_info, "return_logits": save_trajectory_every > 0} + traj_prots = self.sample_formatting( + x=snap["x"], + extra_info=traj_extra, + ret_mode="coors37_n_aatype", + ) + step_samples = [] + for i in range(traj_prots["coors"].shape[0]): + sample_data = (traj_prots["coors"][i], traj_prots["residue_type"][i]) + if "seq_logits" in traj_prots: + sample_data = sample_data + (traj_prots["seq_logits"][i],) + step_samples.append(sample_data) + trajectory_list.append({"step": snap["step"], "samples": step_samples}) + + return generation_list, trajectory_list def sample_formatting( self, @@ -590,7 +661,8 @@ def sample_formatting( ) elif data_modes == ["bb_ca", "local_latents"]: return self._format_sample_local_latents( - x=x, ret_mode=ret_mode, mask=extra_info["mask"] + x=x, ret_mode=ret_mode, mask=extra_info["mask"], + return_logits=extra_info.get("return_logits", False), ) else: raise NotImplementedError(f"Format {ret_mode} not implemented") @@ -655,6 +727,7 @@ def _format_sample_local_latents( x: Dict[str, torch.Tensor], ret_mode: str, mask: Bool[torch.Tensor, "b n"], + return_logits: bool = False, ): """ Given a batch of b samples consisting on `bb_ca` and `local_latents` this @@ -689,11 +762,14 @@ def _format_sample_local_latents( return x elif ret_mode == "coors37_n_aatype": - return { + result = { "coors": nm_to_ang(output_decoder["coors_nm"]), # [b, n, 37, 3] "residue_type": output_decoder["residue_type"], # [b, n] "mask": output_decoder["residue_mask"], # [b, n] } + if return_logits and "seq_logits" in output_decoder: + result["seq_logits"] = output_decoder["seq_logits"] # [b, n, 20] + return result elif ret_mode == "pdb_string": pdb_strings = [] diff --git a/results/stochastic_heun_benchmark/comparison_report.md b/results/stochastic_heun_benchmark/comparison_report.md new file mode 100644 index 00000000..07945ab0 --- /dev/null +++ b/results/stochastic_heun_benchmark/comparison_report.md @@ -0,0 +1,100 @@ +# Stochastic Heun-100 vs SDE Euler Baseline + +**Date:** 2026-03-14 +**Config:** `inference_ucond_notri_stochastic_heun` (inherits `inference_ucond_notri`, sampling_mode: sc) + +## Method + +Stochastic Heun applies the trapezoidal rule to the SDE drift while keeping the +same Brownian increment for both the trial and corrected steps: + +``` +dW = sqrt(dt) * randn() # sample once +x_trial = x + f(x,t)*dt + s(t)*dW # Euler-Maruyama trial +f2 = drift(x_trial, t+dt) # 2nd function eval +x_new = x + 0.5*(f1+f2)*dt + s(t)*dW # trapezoidal correction on drift +``` + +Strong convergence order 1.0 (vs 0.5 for Euler-Maruyama). Works because the +diffusion coefficient depends only on t (additive noise), so no Ito correction +is needed. + +## Results Summary + +Outliers with scRMSD > 20A are excluded from mean/median computations. +Designability threshold: scRMSD < 2.0A. + +### Designability (ProteinMPNN + ESMFold scRMSD) + +| Run | NFEs | Samples | Outliers | Designable | Rate | Mean scRMSD | Median scRMSD | +|-----|------|---------|----------|------------|------|-------------|---------------| +| SDE Euler-400 (1) | 400 | 1890 | 374 | 1052/1516 | 69.4% | 4.679 A | 0.836 A | +| **Stoch-Heun-100** | **200** | **150** | **2** | **126/148** | **85.1%** | **1.449 A** | **0.695 A** | + +### Designability by Length + +| Length | Euler-400 Rate | Euler-400 Med | Stoch-Heun Rate | Stoch-Heun Med | +|--------|---------------|---------------|-----------------|----------------| +| L=100 | 60.5% (362/598) | 0.939 A | 100.0% (49/49) | 0.498 A | +| L=200 | 71.2% (353/496) | 0.810 A | 82.0% (41/50) | 0.759 A | +| L=300 | 79.9% (337/422) | 0.818 A | 73.5% (36/49) | 1.255 A | + +### Backbone Perplexity + +| Length | SDE Euler (2) | Stoch-Heun-100 | Delta | +|--------|---------------|----------------|-------| +| L=100 | 3.183 +/- 0.345 | 3.201 +/- 0.373 | +0.018 | +| L=200 | 3.087 +/- 0.352 | 3.226 +/- 0.461 | +0.139 | +| L=300 | 3.013 +/- 0.353 | 2.915 +/- 0.391 | -0.098 | +| **Overall** | **3.095** | **3.114** | **+0.019** | + +### Timing and Memory (A10G, 10 samples per batch, BF16+SDPA) + +| Length | SDE Euler-200 (2) | Stoch-Heun-100 | Speedup | +|--------|-------------------|----------------|---------| +| L=100 | 13.9s / 2.50 GB | 8.9s / 2.73 GB | 1.56x | +| L=200 | 40.5s / 4.96 GB | 28.3s / 4.89 GB | 1.43x | +| L=300 | 76.9s / 9.07 GB | 58.5s / 8.48 GB | 1.31x | +| **Total (150 samples)** | **657s / 9.07 GB** | **478s / 8.48 GB** | **1.37x** | + +Both Euler-200 and Stoch-Heun-100 use 200 NFEs. The speedup comes from fewer +Python loop iterations (100 vs 200) while the NN eval cost per NFE is the same. + +Estimated SDE Euler-400 (400 NFEs) total: ~1314s. Stoch-Heun-100 speedup vs +Euler-400: **~2.75x**. + +## Notes + +1. **SDE Euler-400 designability** is from a Feb 16 run (commit 3c5e5a9, same + model weights, 630 samples/length). The larger sample and older code version + may not be perfectly matched to the Stoch-Heun-100 run. A matched 50-sample + Euler-400 SDE eval would give a cleaner comparison. + +2. **Perplexity and timing baseline** comes from the Mar 12 benchmark_baseline + run, which used nsteps=200 (not 400). The perplexity comparison is therefore + Euler-200 vs Stoch-Heun-100 (both 200 NFEs); Euler-400 perplexity would be + slightly higher (more steps typically improves quality). + +3. **Memory usage** is comparable between methods; the stochastic Heun holds one + extra noise tensor per modality, which is negligible. + +4. The high outlier rate in the Euler-400 baseline (374/1890 = 19.8% with scRMSD + > 20A) versus Stoch-Heun-100 (2/150 = 1.3%) is a notable difference. This + could be due to the larger sample size or a genuine improvement from the + higher-order solver. + +## Reproduction + +```bash +# Full three-way benchmark (Euler-400 vs Heun-100 vs Stoch-Heun-100) +bash scripts/benchmark_solvers.sh 50 "100 200 300" + +# Or run Stochastic Heun only: +python proteinfoundation/generate.py \ + --config_name inference_ucond_notri_stochastic_heun \ + --nsamples 50 --lengths 100 200 300 \ + --output_dir ./inference/stochastic_heun + +python proteinfoundation/eval_trajectory.py \ + --input_dir ./inference/stochastic_heun +``` diff --git a/scripts/benchmark_solvers.sh b/scripts/benchmark_solvers.sh new file mode 100755 index 00000000..403d4d81 --- /dev/null +++ b/scripts/benchmark_solvers.sh @@ -0,0 +1,164 @@ +#!/bin/bash +# Benchmark solver comparison: Euler-400 vs Heun-100 (ODE) vs Stochastic-Heun-100 (SDE) +# +# Generates samples with each solver, runs both fast eval (structural metrics, +# seconds) and full designability eval (ProteinMPNN + ESMFold scRMSD, hours), +# then prints a side-by-side summary. +# +# Requires: GPU with CUDA, checkpoints in ./checkpoints_laproteina/ +# +# Usage: +# bash scripts/benchmark_solvers.sh # default: 10 samples, L=100,200,300 +# bash scripts/benchmark_solvers.sh 50 "100 200 300" # custom sample count and lengths +# +# Output goes to ./inference/benchmark// for each solver. + +set -euo pipefail + +NSAMPLES="${1:-10}" +LENGTHS="${2:-100 200 300}" + +BENCHMARK_DIR="./inference/benchmark" +CONFIG="inference_ucond_notri" + +SOLVERS=("euler" "heun" "stochastic_heun") +NSTEPS=(400 100 100) +LABELS=("Euler-400 (SDE baseline)" "Heun-100 (ODE, 200 NFEs)" "Stoch-Heun-100 (SDE, 200 NFEs)") + +echo "==============================================" +echo " Solver Benchmark" +echo "==============================================" +echo " Samples per length: $NSAMPLES" +echo " Lengths: $LENGTHS" +echo " Solvers: ${SOLVERS[*]}" +echo " Output: $BENCHMARK_DIR" +echo "==============================================" +echo "" + +# ── 1. Generate samples with each solver ───────────────────────────────────── + +for i in "${!SOLVERS[@]}"; do + solver="${SOLVERS[$i]}" + nsteps="${NSTEPS[$i]}" + label="${LABELS[$i]}" + outdir="${BENCHMARK_DIR}/${solver}" + + echo "$(date '+%H:%M:%S') Generating: $label (nsteps=$nsteps)..." + python proteinfoundation/generate.py \ + --config_name "$CONFIG" \ + --solver "$solver" \ + --nsteps "$nsteps" \ + --nsamples "$NSAMPLES" \ + --lengths $LENGTHS \ + --output_dir "$outdir" + + echo " -> Saved to $outdir" + echo "" +done + +# ── 2. Fast eval (structural metrics; seconds) ────────────────────────────── + +echo "==============================================" +echo " Fast Eval (structural metrics)" +echo "==============================================" +echo "" + +for i in "${!SOLVERS[@]}"; do + solver="${SOLVERS[$i]}" + label="${LABELS[$i]}" + outdir="${BENCHMARK_DIR}/${solver}" + + echo "$(date '+%H:%M:%S') Fast eval: $label ..." + python proteinfoundation/eval_trajectory_fast.py \ + --input_dir "$outdir" + + echo " -> ${outdir}/trajectory_fast_eval.csv" + echo "" +done + +# ── 3. Full designability eval (ProteinMPNN + ESMFold; hours) ──────────────── + +echo "==============================================" +echo " Full Eval (ProteinMPNN + ESMFold scRMSD)" +echo "==============================================" +echo "" + +for i in "${!SOLVERS[@]}"; do + solver="${SOLVERS[$i]}" + label="${LABELS[$i]}" + outdir="${BENCHMARK_DIR}/${solver}" + + echo "$(date '+%H:%M:%S') Full eval: $label ..." + python proteinfoundation/eval_trajectory.py \ + --input_dir "$outdir" + + echo " -> ${outdir}/trajectory_eval.csv" + echo "" +done + +# ── 4. Summary ─────────────────────────────────────────────────────────────── + +echo "==============================================" +echo " Benchmark Complete" +echo "==============================================" +echo "" + +python3 -c " +import pandas as pd +import os + +solvers = ['euler', 'heun', 'stochastic_heun'] +labels = ['Euler-400 (SDE)', 'Heun-100 (ODE)', 'Stoch-Heun-100 (SDE)'] +base = '${BENCHMARK_DIR}' + +# ── Designability (full eval) ── +rows = [] +for solver, label in zip(solvers, labels): + csv_path = os.path.join(base, solver, 'trajectory_eval.csv') + if not os.path.exists(csv_path): + continue + df = pd.read_csv(csv_path) + if 'step' in df.columns: + df = df.loc[df.groupby('sample')['step'].idxmax()] + n = len(df) + scrmsd_col = [c for c in df.columns if 'scrmsd' in c.lower() or 'sc_rmsd' in c.lower()] + if scrmsd_col: + scrmsd = df[scrmsd_col[0]] + valid = scrmsd[scrmsd <= 20.0] + designable = (valid < 2.0).sum() + rows.append({ + 'Solver': label, + 'Samples': n, + 'Outliers': n - len(valid), + 'Designable': f'{designable}/{len(valid)}', + 'Rate': f'{100*designable/len(valid):.1f}%' if len(valid) > 0 else 'N/A', + 'Median scRMSD': f'{valid.median():.3f}A', + }) + +if rows: + print('Designability (scRMSD < 2.0A, outliers > 20A excluded):') + print(pd.DataFrame(rows).to_string(index=False)) + print() + +# ── Structural metrics (fast eval) ── +rows2 = [] +for solver, label in zip(solvers, labels): + csv_path = os.path.join(base, solver, 'trajectory_fast_eval.csv') + if not os.path.exists(csv_path): + continue + df = pd.read_csv(csv_path) + if 'step' in df.columns: + df = df.loc[df.groupby('sample')['step'].idxmax()] + info = {'Solver': label, 'N': len(df)} + for col in ['radius_of_gyration', 'helix_fraction', 'sheet_fraction']: + if col in df.columns: + info[col] = f'{df[col].mean():.3f}' + rows2.append(info) + +if rows2: + print('Structural metrics (fast eval):') + print(pd.DataFrame(rows2).to_string(index=False)) +" 2>/dev/null || echo "(install pandas for automatic summary)" + +echo "" +echo "=============================================="