diff --git a/CHANGELOG.md b/CHANGELOG.md index 8da32d6..e16c179 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## 7.1.0 + +### Added +- Use libqdx for TRC and file conversion capabilities; conversion to/from SDF, to/from XYZ, and + `perceive_bonds` and `perceive_formal_charges` member functions on TRCs are now available + ## 7.0.0 No new changes since 7.0.0rc3. diff --git a/examples/exess-interaction-energy/05_exess_interaction_energy.py b/examples/exess-interaction-energy/05_exess_interaction_energy.py index 7872bd7..fc468bb 100644 --- a/examples/exess-interaction-energy/05_exess_interaction_energy.py +++ b/examples/exess-interaction-energy/05_exess_interaction_energy.py @@ -14,7 +14,7 @@ from pathlib import Path -from rush import RunOpts, exess, prepare +from rush import RunOpts, exess, get_fragments_near_fragment, prepare # ===== Example 1: Fragment-based interaction energy ===== print("=" * 60) @@ -80,7 +80,7 @@ # Step 2: Find ligand fragment index + nearby pocket lig_idx = trc.residues.seqs.index("MK1") -frag_idcs = trc.topology.get_fragments_near_fragment(lig_idx, 5.0) + [lig_idx] +frag_idcs = get_fragments_near_fragment(trc.topology, lig_idx, 5.0) + [lig_idx] # Step 3: Run interaction energy diff --git a/examples/exess-qmmm/06_exess_qmmm.py b/examples/exess-qmmm/06_exess_qmmm.py index ccb0734..35354cc 100644 --- a/examples/exess-qmmm/06_exess_qmmm.py +++ b/examples/exess-qmmm/06_exess_qmmm.py @@ -19,12 +19,13 @@ """ import json -from itertools import batched from pathlib import Path -from rush import RunOpts, Topology, exess +import numpy as np + +from rush import TRC, RunOpts, Topology, exess from rush.exess import Trajectory -from rush.mol import Element, Fragment, Residue, Residues +from rush.mol import Element DATA_DIR = Path(__file__).parent / "data" OUTPUT_DIR = Path(__file__).parent / "qmmm-outputs" @@ -72,7 +73,8 @@ out_traj = res.geometries # Load topology for atom info -topology = Topology.from_json(topology_path) +with 
open(topology_path) as f: + topology = Topology.from_dict(json.load(f)) assert topology.fragments, "Topology lost its fragments!" n_atoms = len(topology.symbols) @@ -119,6 +121,10 @@ def geometry_to_xyz(syms, geom, frame_label=""): return "\n".join(lines) +# NOTE: topology.geometry is an (N,3) numpy array, but out_traj geometries +# from the server are flat lists. geometry_to_xyz handles the flat format. + + # Build all frames as XYZ strings all_frames_xyz = [] for i, geom in enumerate(out_traj): @@ -337,38 +343,23 @@ def geometry_to_xyz(syms, geom, frame_label=""): print("Example 2: Minimal QM/MM (two water molecules)") print("=" * 60) -topology = Topology( +trc = TRC( symbols=[Element.O, Element.H, Element.H, Element.O, Element.H, Element.H], geometry=[ - 0.0000, - 0.0000, - 0.0000, - 0.7570, - 0.5860, - 0.0000, - -0.7570, - 0.5860, - 0.0000, - 2.8000, - 0.0000, - 0.0000, - 3.5570, - 0.5860, - 0.0000, - 2.0430, - 0.5860, - 0.0000, + [0.0000, 0.0000, 0.0000], + [0.7570, 0.5860, 0.0000], + [-0.7570, 0.5860, 0.0000], + [2.8000, 0.0000, 0.0000], + [3.5570, 0.5860, 0.0000], + [2.0430, 0.5860, 0.0000], ], - fragments=[Fragment([0, 1, 2]), Fragment([3, 4, 5])], -) - -residues = Residues( - residues=[Residue([0, 1, 2]), Residue([3, 4, 5])], - seqs=["HOH", "HOH"], + fragments=[[0, 1, 2], [3, 4, 5]], + residues=[[0, 1, 2], [3, 4, 5]], + residue_seqs=["HOH", "HOH"], ) run2 = exess.qmmm( - (topology, residues), + trc, n_timesteps=100, trajectory=Trajectory(include_waters=True), mm_fragments=[], @@ -386,10 +377,11 @@ def geometry_to_xyz(syms, geom, frame_label=""): out_traj = res2.geometries print("Atoms at First Step") -for atom_x, atom_y, atom_z in batched(topology.geometry, 3): - print(f" x: {atom_x:>7.4f}, y: {atom_y:>7.4f}, z: {atom_z:>7.4f}") +for x, y, z in trc.topology.geometry: + print(f" x: {x:>7.4f}, y: {y:>7.4f}, z: {z:>7.4f}") -topology.geometry = out_traj[-1] +final_geom = np.array(out_traj[-1], dtype=np.float32).reshape(-1, 3) +final_trc = 
TRC(symbols=list(trc.topology.symbols), geometry=final_geom.tolist()) print("Atoms at Final Step") -for atom_x, atom_y, atom_z in batched(topology.geometry, 3): - print(f" x: {atom_x:>7.4f}, y: {atom_y:>7.4f}, z: {atom_z:>7.4f}") +for x, y, z in final_trc.topology.geometry: + print(f" x: {x:>7.4f}, y: {y:>7.4f}, z: {z:>7.4f}") diff --git a/pyproject.toml b/pyproject.toml index b6570ee..1d087e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rush-py" -version = "7.0.0" +version = "7.1.0" description = "Python client for QDX's Rush platform" readme = "README.md" requires-python = ">=3.12" @@ -11,6 +11,7 @@ authors = [ dependencies = [ "gql~=4.0", "h5py~=3.14", + "libqdx~=0.8.0", "numpy>=1.26,<3", "rdkit~=2025.9.2", "matplotlib~=3.10", diff --git a/src/rush/__init__.py b/src/rush/__init__.py index 746054f..a89138d 100644 --- a/src/rush/__init__.py +++ b/src/rush/__init__.py @@ -17,24 +17,24 @@ ) from .mol import ( TRC, + Topology, + Residues, + Chains, + AlphaHelices, AminoAcidSeq, AtomRef, + BetaSheets, Bond, BondOrder, - Chain, ChainRef, - Chains, Element, - FormalCharge, - Fragment, FragmentRef, - PartialCharge, - Residue, - ResidueId, + HelixClass, ResidueRef, - Residues, - SchemaVersion, - Topology, + Stereochemistry, + StrandSense, + AtomCheckStrictness, + get_fragments_near_fragment, ) from .objects import ( ObjectID, @@ -87,26 +87,24 @@ "ObjectID", "TRCPaths", "TRCRef", - # Core structures + # Core types "TRC", "Topology", "Residues", "Chains", - # Chemistry types "Element", "Bond", "BondOrder", - "FormalCharge", - "PartialCharge", - "Fragment", - "FragmentRef", + "Stereochemistry", + "HelixClass", + "StrandSense", + "AlphaHelices", + "BetaSheets", + "AtomCheckStrictness", "AminoAcidSeq", - "SchemaVersion", - # Indices and records "AtomRef", - "Residue", "ResidueRef", - "ResidueId", - "Chain", "ChainRef", + "FragmentRef", + "get_fragments_near_fragment", ] diff --git a/src/rush/convert/__init__.py 
b/src/rush/convert/__init__.py index b12a3a7..5a38dbf 100644 --- a/src/rush/convert/__init__.py +++ b/src/rush/convert/__init__.py @@ -1,34 +1,86 @@ """ Conversion utilities for molecular structure file formats. -This module provides functions to convert between PDB, mmCIF, SDF, and QDX's TRC JSON formats. +Format parsing and writing are backed by the native libqdx Rust library. """ -import copy import json as std_json from collections.abc import Sequence from pathlib import Path from typing import TypeGuard +import libqdx + from ..mol import TRC from .json import from_json, to_dict -from .mmcif import from_mmcif -from .pdb import from_pdb, to_pdb -from .sdf import from_sdf -def load_structure(file_path: str | Path) -> TRC | list[TRC]: +def from_pdb(pdb_content: str) -> TRC | list[TRC]: + """Parse PDB file content into TRC structures. + + Args: + pdb_content: Raw PDB file text. + + Returns: + A single TRC if the file contains one model, otherwise a list of TRCs. + """ + trcs = libqdx.from_pdb(pdb_content) + return trcs[0] if len(trcs) == 1 else trcs + + +def to_pdb(trc: TRC) -> str: + """Convert a TRC structure to PDB format text. + + Args: + trc: TRC structure to serialise. + + Returns: + PDB-formatted string (includes trailing END record). + """ + return libqdx.to_pdb(trc) + + +def from_mmcif(mmcif_content: str) -> TRC | list[TRC]: + """Parse mmCIF file contents into TRC structures. + + Args: + mmcif_content: Raw mmCIF file text. + + Returns: + A single TRC if the file contains one model, otherwise a list of TRCs. """ - Load structure from PDB, mmCIF, or JSON file. + trcs = libqdx.from_mmcif(mmcif_content) + return trcs[0] if len(trcs) == 1 else trcs + + +def from_sdf(sdf_content: str) -> TRC | list[TRC]: + """Parse SDF file contents into TRC structures. Args: - file_path: Path to structure file + sdf_content: Raw SDF / MOL file text. Returns: - TRC structure or list of TRC structures + A single TRC if the file contains one molecule, otherwise a list of TRCs. 
+ """ + trcs = libqdx.from_sdf(sdf_content) + return trcs[0] if len(trcs) == 1 else trcs + + +def load_structure(file_path: str | Path) -> TRC | list[TRC]: + """Load a molecular structure from a file. + + Supported formats: PDB, mmCIF (.cif / .mmcif), SDF, and TRC JSON. + The format is determined by extension; when the extension is + unrecognised the content is inspected heuristically. + + Args: + file_path: Path to the structure file. + + Returns: + A single TRC when the file contains one model/molecule, otherwise + a list of TRCs. """ path = Path(file_path) - # Determine file type by extension suffix = path.suffix.lower() if suffix == ".json": return from_json(path) @@ -39,8 +91,10 @@ def load_structure(file_path: str | Path) -> TRC | list[TRC]: return from_mmcif(content) elif suffix == ".pdb": return from_pdb(content) + elif suffix == ".sdf": + return from_sdf(content) else: - # Try to guess from content + # Unrecognised extension — try to guess from content content_lower = content.lower() if content.strip().startswith("[") or content.strip().startswith("{"): return from_json(std_json.loads(content)) @@ -53,23 +107,28 @@ def load_structure(file_path: str | Path) -> TRC | list[TRC]: def save_structure( trcs: TRC | list[TRC], file_path: str | Path, format: str | None = None ): - """ - Save TRC structures to file. + """Save TRC structures to a file. Args: - trcs: TRC structure or list of TRC structures - file_path: Output file path - format: Output format ('pdb', 'json', or None for auto-detect from extension) + trcs: TRC structure or list of TRC structures to write. + file_path: Output file path. + format: Output format (``'pdb'`` or ``'json'``). When *None* the + format is inferred from the file extension. + + Raises: + ValueError: If the format cannot be inferred or is unsupported. 
""" path = Path(file_path) if format is None: - # Auto-detect from extension - if path.suffix.lower() == ".json": + suffix = path.suffix.lower() + if suffix == ".json": format = "json" - elif path.suffix.lower() == ".pdb": + elif suffix == ".pdb": format = "pdb" else: - format = "pdb" # Default + raise ValueError( + f"Cannot infer format from extension '{suffix}'; pass format= explicitly" + ) if format.lower() == "json": with path.open("w") as f: @@ -79,7 +138,7 @@ def save_structure( if isinstance(trcs, TRC): trcs = [trcs] if len(trcs) > 1: - # Multi-model PDB + # Multi-model PDB: wrap each TRC in MODEL/ENDMDL records content_parts = [] for i, trc in enumerate(trcs, 1): content_parts.append(f"MODEL {i:>4}") @@ -129,16 +188,13 @@ def _load_trc(trc: TrcInput) -> TRC: if isinstance(trc, TRC): return trc if isinstance(trc, (str, Path)): - path = Path(trc) - if not path.exists(): - raise FileNotFoundError(f"TRC file not found: {trc}") - loaded: TRC | list[TRC] = from_json(path) + loaded = load_structure(trc) if _is_trc_list(loaded): if len(loaded) == 1: return loaded[0] - merged = copy.deepcopy(loaded[0]) + merged = loaded[0] for next_trc in loaded[1:]: - merged.extend(next_trc) + merged = merged.extend(next_trc) return merged if isinstance(loaded, list): raise TypeError("Expected TRC list elements to be TRC objects") @@ -152,44 +208,32 @@ def merge_trcs( skip_validation: bool = False, ) -> TRC: """ - Merge TRC objects into a single TRC. + Merge one or more TRC objects (or file paths) into a single TRC. - A TRC (Topology-Residues-Chains) object contains: - - topology: atom information (symbols, geometry, bonds, charges, etc.) - - residues: residue information (which atoms belong to which residues) - - chains: chain information (which residues belong to which chains) - - When merging, atom indices, residue indices, and chain indices are renumbered - to ensure uniqueness in the merged structure. 
+ Atom, residue, and chain indices are renumbered so that the merged + structure has unique indices throughout. Args: - trcs: TRC objects or file paths. If a single list/tuple is provided, - it is treated as the full set of inputs. - output_file: Optional path to write the merged TRC JSON. - skip_validation: If True, skip validation of the merged TRC. + trcs: TRC objects or file paths. A single list/tuple is treated + as the full set of inputs. + output_file: Optional path to write the merged TRC as JSON. + skip_validation: If *True*, skip ``trc.check()`` on the result. Returns: - Merged TRC object. + The merged TRC object. Raises: ValueError: If no inputs are provided or validation fails. - FileNotFoundError: If file paths are provided but files don't exist. + FileNotFoundError: If a file path does not exist. """ trc_inputs = _normalize_trc_inputs(trcs) if not trc_inputs: raise ValueError("Expected at least one TRC input, found 0") - merged: TRC | None = None - for trc in trc_inputs: - trc_obj = _load_trc(trc) - if merged is None: - merged = copy.deepcopy(trc_obj) - else: - merged.extend(trc_obj) - - if merged is None: - raise ValueError("Expected at least one TRC input, found 0") + merged = _load_trc(trc_inputs[0]) + for trc in trc_inputs[1:]: + merged = merged.extend(_load_trc(trc)) if not skip_validation: merged.check() diff --git a/src/rush/convert/json.py b/src/rush/convert/json.py index cda5ce3..19275ce 100644 --- a/src/rush/convert/json.py +++ b/src/rush/convert/json.py @@ -8,20 +8,13 @@ from pathlib import Path from typing import TypeGuard, overload -from ..mol import TRC, Chains, Residues, Topology +from ..mol import TRC StrPath = PathLike[str] def _trcs_from_dicts(trc_dicts: list[dict]) -> list[TRC]: - return [ - TRC( - topology=Topology.from_json(trc_dict["topology"]), - residues=Residues.from_json(trc_dict["residues"]), - chains=Chains.from_json(trc_dict["chains"]), - ) - for trc_dict in trc_dicts - ] + return [TRC.from_dict(d) for d in trc_dicts] def 
_is_dict_list(value: object) -> TypeGuard[list[dict]]: @@ -102,25 +95,18 @@ def from_json( def to_dict( trcs: TRC | list[TRC], -) -> dict[str, dict[str, object]] | list[dict[str, dict[str, object]]]: +) -> dict | list[dict]: """ - Convert TRC structures to JSON. + Convert TRC structures to JSON-serializable dicts. Args: - trcs: TRC structure or list of TRC structures + trcs: A single TRC or a list of TRC structures. Returns: - JSON-compatible dict or list of dicts + A dict (if a single TRC was given) or a list of dicts, each + containing ``topology``, ``residues``, and ``chains`` keys. """ - if isinstance(trcs, TRC): trcs = [trcs] - data = [ - { - "topology": trc.topology.to_dict(), - "residues": trc.residues.to_dict(), - "chains": trc.chains.to_dict(), - } - for trc in trcs - ] + data = [trc.to_dict() for trc in trcs] return data[0] if len(data) == 1 else data diff --git a/src/rush/convert/mmcif.py b/src/rush/convert/mmcif.py deleted file mode 100644 index c5404e0..0000000 --- a/src/rush/convert/mmcif.py +++ /dev/null @@ -1,568 +0,0 @@ -""" -mmCIF file parsing functionality. -""" - -from collections import OrderedDict, defaultdict - -from ..mol import ( - TRC, - AtomRef, - Bond, - BondOrder, - Chain, - ChainRef, - Element, - FormalCharge, - Fragment, - Residue, - ResidueId, - ResidueRef, -) - - -def _parse_mmcif_value(value: str) -> str: - """Parse an mmCIF value, handling quoted strings and special characters.""" - value = value.strip() - if value in (".", "?"): - return "" - if value.startswith("'") and value.endswith("'"): - return value[1:-1] - if value.startswith('"') and value.endswith('"'): - return value[1:-1] - return value - - -def _parse_mmcif_loop( - lines: list[str], start_idx: int, prefix: str -) -> tuple[tuple[list[str], list[list[str]]] | None, int]: - """ - Parse an mmCIF loop starting at start_idx. 
- - Returns: - ((column_names, rows), next_idx) or (None, next_idx) if not a loop with the given prefix - """ - i = start_idx - - # Check if this is a loop - if i >= len(lines) or not lines[i].strip().startswith("loop_"): - return (None, i) - - i += 1 - - # Parse column names - columns = [] - while i < len(lines): - line = lines[i].strip() - if not line or line.startswith("#"): - i += 1 - continue - if not line.startswith("_"): - break - if line.startswith(prefix): - columns.append(line[len(prefix) :]) - elif ( - columns - ): # Started collecting columns for this prefix, now hit a different prefix - break - i += 1 - - if not columns: - return (None, i) - - # Parse data rows (may span multiple lines) - rows = [] - while i < len(lines): - line = lines[i].strip() - if not line or line.startswith("#"): - i += 1 - continue - if line.startswith("_") or line.startswith("loop_"): - break - - # Parse fields from current line (and additional lines if needed) - fields = [] - current_line = lines[i] - i += 1 - - while len(fields) < len(columns): - # Parse tokens from current_line - tokens = [] - j = 0 - current_line_stripped = current_line.rstrip("\n\r") - while j < len(current_line_stripped): - # Skip whitespace - while ( - j < len(current_line_stripped) and current_line_stripped[j] in " \t" - ): - j += 1 - if j >= len(current_line_stripped): - break - - # Check for quoted string - if current_line_stripped[j] in ("'", '"'): - quote_char = current_line_stripped[j] - j += 1 - start = j - while ( - j < len(current_line_stripped) - and current_line_stripped[j] != quote_char - ): - j += 1 - tokens.append(current_line_stripped[start:j]) - j += 1 # Skip closing quote - else: - # Unquoted value - start = j - while ( - j < len(current_line_stripped) - and current_line_stripped[j] not in " \t" - ): - j += 1 - tokens.append(current_line_stripped[start:j]) - - fields.extend(tokens) - - # If we don't have enough fields yet, try to read the next line - if len(fields) < len(columns): - if i 
< len(lines): - next_line = lines[i].strip() - if ( - next_line - and not next_line.startswith("_") - and not next_line.startswith("loop_") - and not next_line.startswith("data_") - ): - current_line = lines[i] - i += 1 - else: - break - else: - break - - if len(fields) == len(columns): - rows.append(fields) - - return ((columns, rows), i) - - -def _build_trc_from_mmcif_atoms( - atoms: list[dict], - struct_conn_data: tuple[list[str], list[list[str]]] | None, - comp_bond_data: tuple[list[str], list[list[str]]] | None, -) -> TRC: - """Build a TRC from parsed mmCIF atoms.""" - trc = TRC() - - atom_ids = [] - atom_labels = [] - atom_formal_charges = [] - atom_symbols = [] - geometry = [] - - residue_data = OrderedDict() - chain_data = defaultdict(set) - atom_index_map = {} # Original atom index to topology index - - for orig_idx, atom in enumerate(atoms): - # Only process atoms with alternate location "A" or None - alt_id = atom["label_alt_id"] - if alt_id and alt_id != "A": - continue - - # Parse element from type_symbol - type_symbol = atom["type_symbol"] - # Remove non-alphabetic characters - element_str = "".join(c for c in type_symbol if c.isalpha()) - try: - element = Element.from_str(element_str) - except (ValueError, KeyError): - element = Element.C # Default to carbon - - topology_idx = len(atom_symbols) - atom_index_map[orig_idx] = topology_idx - - atom_symbols.append(element) - geometry.extend([atom["Cartn_x"], atom["Cartn_y"], atom["Cartn_z"]]) - - atom_ids.append(atom["id"]) - atom_labels.append(atom["label_atom_id"]) - atom_formal_charges.append(atom["pdbx_formal_charge"]) - - # Create residue identifier using auth fields and "~" for sorting - residue_id = ResidueId( - chain_id=atom["auth_asym_id"], - sequence_number=atom["auth_seq_id"], - insertion_code=atom["pdbx_PDB_ins_code"] or "~", - residue_name=atom["label_comp_id"], - ) - - if residue_id not in residue_data: - residue_data[residue_id] = [] - residue_data[residue_id].append(len(atom_symbols) - 1) 
- - chain_data[atom["auth_asym_id"]].add(residue_id) - - # Build topology - trc.topology.symbols = atom_symbols - trc.topology.geometry = geometry - trc.topology.labels = atom_labels - trc.topology.formal_charges = [ - FormalCharge(charge) for charge in atom_formal_charges - ] - - # Sort residues by ResidueId - sorted_residue_ids = sorted( - residue_data.keys(), - key=lambda rid: ( - rid.chain_id, - rid.sequence_number, - rid.insertion_code, - rid.residue_name, - ), - ) - - # Build residues - residue_list = [] - seq_names = [] - seq_numbers = [] - insertion_codes_list = [] - - for residue_id in sorted_residue_ids: - atom_indices = residue_data[residue_id] - residue_list.append(Residue([AtomRef(idx) for idx in atom_indices])) - seq_names.append(residue_id.residue_name) - seq_numbers.append(residue_id.sequence_number) - # Convert "~" back to empty string - insertion_code = ( - "" if residue_id.insertion_code == "~" else residue_id.insertion_code - ) - insertion_codes_list.append(insertion_code) - - trc.residues.residues = residue_list - trc.residues.seqs = seq_names - trc.residues.seq_ns = seq_numbers - trc.residues.insertion_codes = insertion_codes_list - - # Build chains - chains = [] - residue_id_to_index = {rid: idx for idx, rid in enumerate(sorted_residue_ids)} - chain_ids = sorted(chain_data.keys()) - - for chain_id in chain_ids: - chain_residue_ids = chain_data[chain_id] - sorted_chain_residue_ids = sorted( - chain_residue_ids, - key=lambda rid: (rid.sequence_number, rid.insertion_code, rid.residue_name), - ) - - chain_residue_refs = [ - ResidueRef(residue_id_to_index[rid]) - for rid in sorted_chain_residue_ids - if rid in residue_id_to_index - ] - chains.append(Chain(chain_residue_refs)) - - trc.chains.chains = chains - trc.chains.labeled = [ChainRef(i) for i in range(len(chains))] - trc.chains.labels = [[chain_id] for chain_id in chain_ids] - - # Create fragments (one per residue) - trc.topology.fragments = [ - Fragment([AtomRef(atom_idx) for atom_idx in 
residue.atoms]) - for residue in trc.residues.residues - ] - - # Calculate fragment formal charges - fragment_formal_charges = [] - for residue in trc.residues.residues: - total_charge = sum(atom_formal_charges[atom_idx] for atom_idx in residue.atoms) - fragment_formal_charges.append(FormalCharge(total_charge)) - trc.topology.fragment_formal_charges = fragment_formal_charges - - # Build connectivity from struct_conn and chem_comp_bond - connectivity_deduper = {} # (min_idx, max_idx) -> bond_order - - # Parse struct_conn (inter-residue bonds) - if struct_conn_data: - columns, rows = struct_conn_data - col_idx = {col: idx for idx, col in enumerate(columns)} - - for row in rows: - - def get_val(name: str) -> str: - idx = col_idx.get(name) - if idx is not None and idx < len(row): - return _parse_mmcif_value(row[idx]) - return "" - - def get_int_val(name: str) -> int: - val = get_val(name) - try: - return int(val) if val else 0 - except ValueError: - return 0 - - # Find atoms by label (uses label_ fields, not auth_) - ptnr1_atom = get_val("ptnr1_label_atom_id") - ptnr1_asym = get_val("ptnr1_label_asym_id") - ptnr1_seq = get_int_val("ptnr1_label_seq_id") - ptnr2_atom = get_val("ptnr2_label_atom_id") - ptnr2_asym = get_val("ptnr2_label_asym_id") - ptnr2_seq = get_int_val("ptnr2_label_seq_id") - conn_type = get_val("conn_type_id") - - # Find matching atoms using label_ fields (find FIRST match like Rust .position()) - atom1_orig_idx = None - atom2_orig_idx = None - for idx, atom in enumerate(atoms): - if atom1_orig_idx is None and ( - atom["label_atom_id"] == ptnr1_atom - and atom["label_asym_id"] == ptnr1_asym - and atom["label_seq_id"] == ptnr1_seq - ): - atom1_orig_idx = idx - if atom2_orig_idx is None and ( - atom["label_atom_id"] == ptnr2_atom - and atom["label_asym_id"] == ptnr2_asym - and atom["label_seq_id"] == ptnr2_seq - ): - atom2_orig_idx = idx - if atom1_orig_idx is not None and atom2_orig_idx is not None: - break - - if atom1_orig_idx is not None and 
atom2_orig_idx is not None: - topo_idx1 = atom_index_map.get(atom1_orig_idx) - topo_idx2 = atom_index_map.get(atom2_orig_idx) - - if topo_idx1 is not None and topo_idx2 is not None: - bond_order = 1 # Default to single bond - if conn_type in ["covale", "metalc", "disulf"]: - bond_order = 1 - - min_idx = min(topo_idx1, topo_idx2) - max_idx = max(topo_idx1, topo_idx2) - connectivity_deduper[(min_idx, max_idx)] = bond_order - - # Parse chem_comp_bond (intra-residue bonds) - if comp_bond_data: - columns, rows = comp_bond_data - col_idx = {col: idx for idx, col in enumerate(columns)} - - # Build mapping of comp_id -> bonds - comp_bonds = defaultdict(list) - for row in rows: - - def get_val(name: str) -> str: - idx = col_idx.get(name) - if idx is not None and idx < len(row): - return _parse_mmcif_value(row[idx]) - return "" - - comp_id = get_val("comp_id") - atom_id_1 = get_val("atom_id_1") - atom_id_2 = get_val("atom_id_2") - value_order = get_val("value_order") - - comp_bonds[comp_id].append((atom_id_1, atom_id_2, value_order)) - - # Group atoms by residue for efficient lookup - # Note: Rust uses (comp_id, auth_asym_id, auth_seq_id) without insertion code - residue_atoms = defaultdict( - list - ) # (comp_id, auth_asym_id, auth_seq_id) -> list of (orig_idx, topo_idx, atom) - for orig_idx, atom in enumerate(atoms): - if atom["label_alt_id"] == "" or atom["label_alt_id"] == "A": - topo_idx = atom_index_map.get(orig_idx) - if topo_idx is not None: - key = ( - atom["label_comp_id"], - atom["auth_asym_id"], - atom["auth_seq_id"], - ) - residue_atoms[key].append((orig_idx, topo_idx, atom)) - - # Apply bond definitions to residues - for (comp_id, chain_id, seq_id), res_atoms in residue_atoms.items(): - if comp_id in comp_bonds: - for atom_id_1, atom_id_2, value_order in comp_bonds[comp_id]: - # Find THE FIRST atom that matches each atom_id (Rust uses find()) - topo_idx1 = None - topo_idx2 = None - for _, topo_idx, atom in res_atoms: - if topo_idx1 is None and 
atom["label_atom_id"] == atom_id_1: - topo_idx1 = topo_idx - if topo_idx2 is None and atom["label_atom_id"] == atom_id_2: - topo_idx2 = topo_idx - if topo_idx1 is not None and topo_idx2 is not None: - break - - if topo_idx1 is not None and topo_idx2 is not None: - # Parse bond order - bond_order = 1 - if value_order == "SING": - bond_order = 1 - elif value_order == "DOUB": - bond_order = 2 - elif value_order == "TRIP": - bond_order = 3 - elif value_order == "QUAD": - bond_order = 4 - elif value_order == "AROM": - bond_order = 5 - - min_idx = min(topo_idx1, topo_idx2) - max_idx = max(topo_idx1, topo_idx2) - connectivity_deduper[(min_idx, max_idx)] = bond_order - - # Convert to Bond objects - bonds = [] - for (min_idx, max_idx), order in sorted(connectivity_deduper.items()): - bonds.append(Bond(AtomRef(min_idx), AtomRef(max_idx), BondOrder(order))) - trc.topology.connectivity = bonds - - return trc - - -def from_mmcif(mmcif_content: str) -> TRC | list[TRC]: - """ - Parse mmCIF file contents into TRC structures. 
- - Args: - mmcif_content: String contents of an mmCIF file - - Returns: - TRC structure or list of TRC structures - """ - lines = mmcif_content.split("\n") - trcs = [] - - # Parse loops - models = defaultdict(list) # model_num -> list of atoms - atom_loop_data = None - struct_conn_data = None - comp_bond_data = None - - i = 0 - while i < len(lines): - if lines[i].strip().startswith("loop_"): - # Try to parse atom_site loop - result, next_i = _parse_mmcif_loop(lines, i, "_atom_site.") - if result: - columns, rows = result - # Check if this has atom_site columns - if any("id" in col or "type_symbol" in col for col in columns): - atom_loop_data = (columns, rows) - i = next_i - continue - - # Try to parse struct_conn loop - result, next_i = _parse_mmcif_loop(lines, i, "_struct_conn.") - if result: - struct_conn_data = result - i = next_i - continue - - # Try to parse chem_comp_bond loop - result, next_i = _parse_mmcif_loop(lines, i, "_chem_comp_bond.") - if result: - comp_bond_data = result - i = next_i - continue - - i = next_i - else: - i += 1 - - if not atom_loop_data: - empty_trc = TRC() - empty_trc.chains.labeled = [] - empty_trc.chains.labels = [] - return [empty_trc] - - columns, rows = atom_loop_data - - # Find column indices - col_idx = {} - for idx, col in enumerate(columns): - col_idx[col] = idx - - # Parse atoms - for row in rows: - if len(row) < len(columns): - continue - - def get_val(name: str, default: str = "") -> str: - idx = col_idx.get(name) - if idx is not None and idx < len(row): - val = _parse_mmcif_value(row[idx]) - return val if val else default - return default - - def get_int(name: str, default: int = 0) -> int | None: - val = get_val(name) - if not val: - return None - try: - return int(val) - except ValueError: - return None - - def get_int_with_default(name: str, default: int = 0) -> int: - val = get_int(name) - return val if val is not None else default - - def get_float(name: str, default: float = 0.0) -> float: - val = get_val(name) - 
try: - return float(val) if val else default - except ValueError: - return default - - # Parse auth_seq_id with fallback logic matching Rust - auth_seq_id_val = get_int("auth_seq_id") - if auth_seq_id_val is None: - auth_seq_id_val = get_int("label_seq_id") - if auth_seq_id_val is None: - auth_seq_id_val = 0 - - atom = { - "id": get_int_with_default("id", 0), - "type_symbol": get_val("type_symbol", "C"), - "label_atom_id": get_val("label_atom_id", "C"), - "label_alt_id": get_val("label_alt_id", ""), - "label_comp_id": get_val("label_comp_id", "UNK"), - "label_asym_id": get_val("label_asym_id", "A"), - "label_seq_id": get_int_with_default("label_seq_id", 0), - "pdbx_PDB_ins_code": get_val("pdbx_PDB_ins_code", ""), - "Cartn_x": get_float("Cartn_x", 0.0), - "Cartn_y": get_float("Cartn_y", 0.0), - "Cartn_z": get_float("Cartn_z", 0.0), - "occupancy": get_float("occupancy", 1.0), - "B_iso_or_equiv": get_float("B_iso_or_equiv", 0.0), - "pdbx_formal_charge": get_int_with_default("pdbx_formal_charge", 0), - "auth_asym_id": ( - get_val("auth_asym_id", "") or get_val("label_asym_id", "A") - ), - "auth_seq_id": auth_seq_id_val, - "group_PDB": get_val("group_PDB", "ATOM"), - "pdbx_PDB_model_num": get_val("pdbx_PDB_model_num", "1"), - } - - model_num = atom["pdbx_PDB_model_num"] - models[model_num].append(atom) - - # Build TRC for each model - for model_num in sorted(models.keys()): - atoms = models[model_num] - trc = _build_trc_from_mmcif_atoms(atoms, struct_conn_data, comp_bond_data) - trcs.append(trc) - - if not trcs: - empty_trc = TRC() - empty_trc.chains.labeled = [] - empty_trc.chains.labels = [] - trcs.append(empty_trc) - - if len(trcs) == 1: - return trcs[0] - return trcs diff --git a/src/rush/convert/pdb.py b/src/rush/convert/pdb.py deleted file mode 100644 index 0684afc..0000000 --- a/src/rush/convert/pdb.py +++ /dev/null @@ -1,536 +0,0 @@ -""" -PDB file parsing and writing functionality. 
-""" - -import sys -from collections import OrderedDict, defaultdict -from dataclasses import dataclass - -from ..mol import ( - TRC, - AminoAcidSeq, - AtomRef, - Bond, - BondOrder, - Chain, - ChainRef, - Element, - FormalCharge, - Fragment, - Residue, - ResidueId, - ResidueRef, -) - - -@dataclass -class PDBAtom: - """Represents a parsed PDB ATOM/HETATM record.""" - - atom_idx: int - atom_name: str - alternate_location: str | None - residue_name: str - chain_id: str - sequence_number: int - residue_insertion: str | None - atom_x: float - atom_y: float - atom_z: float - occupancy: float - temperature_factor: float - segment_id: str | None - element_symbol: Element - charge: int | None - - -def _parse_pdb_atom_line(line: str, line_num: int) -> PDBAtom: - """Parse a PDB ATOM or HETATM line.""" - if len(line) < 54: - raise ValueError(f"Line {line_num}: ATOM/HETATM line too short") - - try: - atom_idx = int(line[6:11].strip()) - atom_name = line[12:16].strip() - alternate_location = line[16].strip() if line[16].strip() else None - residue_name = line[17:20].strip() - chain_id = line[21].strip() if len(line) > 21 else "" - sequence_number = int(line[22:26].strip()) if line[22:26].strip() else 1 - residue_insertion = ( - line[26].strip() if len(line) > 26 and line[26].strip() else None - ) - - atom_x = float(line[30:38].strip()) if line[30:38].strip() else 0.0 - atom_y = float(line[38:46].strip()) if line[38:46].strip() else 0.0 - atom_z = float(line[46:54].strip()) if line[46:54].strip() else 0.0 - - occupancy = ( - float(line[54:60].strip()) - if len(line) > 60 and line[54:60].strip() - else 1.0 - ) - temperature_factor = ( - float(line[60:66].strip()) - if len(line) > 66 and line[60:66].strip() - else 0.0 - ) - - segment_id = ( - line[72:76].strip() if len(line) > 76 and line[72:76].strip() else None - ) - - element_symbol_str = ( - line[76:78].strip() - if len(line) > 78 and line[76:78].strip() - else atom_name[0] - ) - element_symbol = 
Element.from_str(element_symbol_str) - - charge = None - if len(line) > 80 and line[78:80].strip(): - charge_str = line[78:80].strip() - if charge_str: - # Parse charge like "+1", "-2", etc. - if charge_str[-1] in "+-": - sign = 1 if charge_str[-1] == "+" else -1 - magnitude = int(charge_str[:-1]) if charge_str[:-1] else 1 - charge = sign * magnitude - else: - charge = int(charge_str) - - return PDBAtom( - atom_idx=atom_idx, - atom_name=atom_name, - alternate_location=alternate_location, - residue_name=residue_name, - chain_id=chain_id, - sequence_number=sequence_number, - residue_insertion=residue_insertion, - atom_x=atom_x, - atom_y=atom_y, - atom_z=atom_z, - occupancy=occupancy, - temperature_factor=temperature_factor, - segment_id=segment_id, - element_symbol=element_symbol, - charge=charge, - ) - except (ValueError, IndexError) as e: - raise ValueError(f"Line {line_num}: Error parsing ATOM/HETATM line: {e}") - - -def _parse_conect_line(line: str) -> list[int]: - """Parse a CONECT line and return list of atom indices.""" - atom_idxs = [] - # CONECT format: positions 6-11, 11-16, 16-21, 21-26, 26-31 for atom indices - start = 6 - while start < len(line): - end = start + 5 - if end > len(line): - end = len(line) - atom_idx_str = line[start:end].strip() - if atom_idx_str: - try: - atom_idxs.append(int(atom_idx_str)) - except ValueError: - break - else: - break - start = end - return atom_idxs - - -def _build_trc( - atoms: list[PDBAtom], - atom_ids: list[int], - residue_data: OrderedDict, - chain_data: dict[str, set[ResidueId]], - connectivity: list[tuple[int, int, int]], -) -> TRC: - """Build a TRC structure from parsed PDB data.""" - - trc = TRC() - - # Build topology - trc.topology.symbols = [atom.element_symbol for atom in atoms] - trc.topology.geometry = [] - for atom in atoms: - trc.topology.geometry.extend([atom.atom_x, atom.atom_y, atom.atom_z]) - - trc.topology.labels = [atom.atom_name for atom in atoms] - - # Formal charges (per atom) - 
atom_formal_charges = [atom.charge or 0 for atom in atoms] - trc.topology.formal_charges = [ - FormalCharge(charge) for charge in atom_formal_charges - ] - - # Sort residues by ResidueId (chain_id, sequence_number, insertion_code, residue_name) - # This matches the Rust BTreeMap ordering - sorted_residue_ids = sorted( - residue_data.keys(), - key=lambda rid: ( - rid.chain_id, - rid.sequence_number, - rid.insertion_code, - rid.residue_name, - ), - ) - - # Build residues in sorted order - residue_list = [] - seq_names = [] - seq_numbers = [] - insertion_codes_list = [] - - for residue_id in sorted_residue_ids: - atom_indices = residue_data[residue_id] - residue_atoms = [AtomRef(idx) for idx in atom_indices] - residue_list.append(Residue(residue_atoms)) - seq_names.append(residue_id.residue_name) - seq_numbers.append(residue_id.sequence_number) - # Convert "~" back to empty string for storage - insertion_code = ( - "" if residue_id.insertion_code == "~" else residue_id.insertion_code - ) - insertion_codes_list.append(insertion_code) - - trc.residues.residues = residue_list - trc.residues.seqs = seq_names - trc.residues.seq_ns = seq_numbers - trc.residues.insertion_codes = insertion_codes_list - - # Build chains - chains = [] - residue_id_to_index = {rid: idx for idx, rid in enumerate(sorted_residue_ids)} - chain_ids = sorted(chain_data.keys()) - - for chain_id in chain_ids: - chain_residue_ids = chain_data[chain_id] - # Sort residues in chain by sequence number - sorted_residue_ids = sorted( - chain_residue_ids, key=lambda rid: (rid.sequence_number, rid.insertion_code) - ) - - chain_residue_refs = [ - ResidueRef(residue_id_to_index[rid]) for rid in sorted_residue_ids - ] - chains.append(Chain(chain_residue_refs)) - - trc.chains.chains = chains - trc.chains.labeled = [ChainRef(i) for i in range(len(chains))] - trc.chains.labels = [[chain_id] for chain_id in chain_ids] - - # Create fragments (one per residue) - amino acids as default fragments - trc.topology.fragments = 
[ - Fragment([AtomRef(atom_idx) for atom_idx in residue.atoms]) - for residue in trc.residues.residues - ] - - # Process connectivity - connectivity_deduper = {} # (origin, target) -> order - for origin_id, target_id, order in connectivity: - # Convert atom IDs to indices - try: - origin_idx = atom_ids.index(origin_id) - except ValueError: - continue - - try: - target_idx = atom_ids.index(target_id) - except ValueError: - continue - - # Check if reverse bond already exists (dedup) - if (target_idx, origin_idx) in connectivity_deduper: - continue - - # If same bond already exists, increment order (double bond) - if (origin_idx, target_idx) in connectivity_deduper: - connectivity_deduper[(origin_idx, target_idx)] += 1 - else: - connectivity_deduper[(origin_idx, target_idx)] = order - - # Convert to Bond objects - bonds = [] - for (origin_idx, target_idx), order in connectivity_deduper.items(): - bonds.append( - Bond( - AtomRef(min(origin_idx, target_idx)), - AtomRef(max(origin_idx, target_idx)), - BondOrder(order), - ) - ) - trc.topology.connectivity = bonds - - # Calculate fragment formal charges (sum of atom charges in each residue) - fragment_formal_charges = [] - for residue in trc.residues.residues: - total_charge = sum(atom_formal_charges[atom_idx] for atom_idx in residue.atoms) - fragment_formal_charges.append(FormalCharge(total_charge)) - trc.topology.fragment_formal_charges = fragment_formal_charges - - return trc - - -def _apply_global_connectivity( - trc: TRC, atom_ids: list[int], global_connectivity: list[tuple[int, int, int]] -): - """Apply global connectivity records to a TRC.""" - if not global_connectivity: - return - - connectivity_deduper = {} # (origin, target) -> order - - for origin_id, target_id, order in global_connectivity: - # Convert atom IDs to indices - try: - origin_idx = atom_ids.index(origin_id) - except ValueError: - continue - - try: - target_idx = atom_ids.index(target_id) - except ValueError: - continue - - # Check if reverse bond 
already exists (dedup) - if (target_idx, origin_idx) in connectivity_deduper: - continue - - # If same bond already exists, increment order (double bond) - if (origin_idx, target_idx) in connectivity_deduper: - connectivity_deduper[(origin_idx, target_idx)] += 1 - else: - connectivity_deduper[(origin_idx, target_idx)] = order - - # Convert to Bond objects - additional_bonds = [] - for (origin_idx, target_idx), order in connectivity_deduper.items(): - additional_bonds.append( - Bond( - AtomRef(min(origin_idx, target_idx)), - AtomRef(max(origin_idx, target_idx)), - BondOrder(order), - ) - ) - - # Add to existing connectivity - if trc.topology.connectivity: - trc.topology.connectivity.extend(additional_bonds) - else: - trc.topology.connectivity = additional_bonds - - -def from_pdb(pdb_content: str) -> TRC | list[TRC]: - """ - Parse PDB file content into TRC structures. - - Args: - pdb_content: String content of a PDB file - - Returns: - TRC structure or list of TRC structures (one per model in multi-model files) - """ - trcs = [] - trc_atom_ids = [] - global_connectivity = [] # List of (origin, target, order) tuples - - lines = pdb_content.strip().split("\n") - line_iter = iter(enumerate(lines, 1)) - - eof = False - while not eof: - # Storage for current model - atoms = [] - atom_ids = [] - residue_data = OrderedDict() # ResidueId -> atom indices - chain_data = defaultdict(set) # chain_id -> set of ResidueIds - connectivity = [] # Local connectivity for this model - - in_model = False - - while True: - try: - line_num, line = next(line_iter) - except StopIteration: - eof = True - break - - if len(line) < 6: - continue - - record_type = line[:6].strip() - - if record_type == "MODEL": - in_model = True - - elif record_type == "ENDMDL": - in_model = False - break - - elif record_type in ["ATOM", "HETATM"]: - in_model = True - - try: - atom = _parse_pdb_atom_line(line, line_num) - - # Only process atoms with alternate location "A" or None - # Skip atoms with other 
alternate locations (e.g., "B", "C", etc.) - if ( - atom.alternate_location is None - or atom.alternate_location == "A" - ): - atoms.append(atom) - atom_ids.append(atom.atom_idx) - - # Create residue identifier - # Note: insertion_code uses "~" for sorting (to sort after all letters) - # but the actual value stored in the residues structure is empty string - residue_id = ResidueId( - chain_id=atom.chain_id, - sequence_number=atom.sequence_number, - insertion_code=atom.residue_insertion or "~", - residue_name=atom.residue_name, - ) - - # Add to residue data - if residue_id not in residue_data: - residue_data[residue_id] = [] - residue_data[residue_id].append( - len(atoms) - 1 - ) # Index in atoms list - - # Add to chain data - chain_data[atom.chain_id].add(residue_id) - # else: skip atoms with other alternate locations - - except ValueError as e: - print(f"Warning: {e}", file=sys.stderr) - continue - - elif record_type == "CONECT": - try: - atom_idxs = _parse_conect_line(line) - if len(atom_idxs) >= 2: - origin = atom_idxs[0] - for target in atom_idxs[1:]: - if in_model: - connectivity.append((origin, target, 1)) - else: - global_connectivity.append((origin, target, 1)) - except (ValueError, IndexError): - continue - - elif record_type == "END": - break - - # If no atoms were found, skip this model - if not atoms: - if eof: - break - else: - continue - - # Build the TRC for this model - trc = _build_trc(atoms, atom_ids, residue_data, chain_data, connectivity) - trcs.append(trc) - trc_atom_ids.append(atom_ids) - - if eof: - break - - # Apply global connectivity to all models - for trc, atom_ids in zip(trcs, trc_atom_ids): - _apply_global_connectivity(trc, atom_ids, global_connectivity) - - # If no TRCs were created, return an empty one - if not trcs: - trcs.append(TRC()) - - if len(trcs) == 1: - return trcs[0] - return trcs - - -def to_pdb(trc: TRC) -> str: - """ - Convert TRC structure to PDB format string. 
- - Args: - trc: TRC structure to convert - - Returns: - PDB format string - """ - lines = [] - - # Create mapping from residue to chain - residue_to_chain = {} - for chain_idx, chain in enumerate(trc.chains.chains): - for residue_idx in chain.residues: - residue_to_chain[residue_idx] = chain_idx - - atom_idx = 1 - for residue_idx, residue in enumerate(trc.residues.residues): - chain_idx = residue_to_chain.get(residue_idx, 0) - chain_id = chr(65 + chain_idx) if chain_idx < 26 else "A" # A, B, C, ... - - residue_name = ( - trc.residues.seqs[residue_idx] - if residue_idx < len(trc.residues.seqs) - else "UNK" - ) - seq_num = ( - trc.residues.seq_ns[residue_idx] - if residue_idx < len(trc.residues.seq_ns) - else 1 - ) - insertion_code = ( - trc.residues.insertion_codes[residue_idx] - if residue_idx < len(trc.residues.insertion_codes) - else "" - ) - - for atom_idx in residue.atoms: - if atom_idx >= len(trc.topology.symbols): - continue - - element = trc.topology.symbols[atom_idx] - atom_name = ( - trc.topology.labels[atom_idx] if trc.topology.labels else str(element) - ) - - x = ( - trc.topology.geometry[atom_idx * 3] - if atom_idx * 3 < len(trc.topology.geometry) - else 0.0 - ) - y = ( - trc.topology.geometry[atom_idx * 3 + 1] - if atom_idx * 3 + 1 < len(trc.topology.geometry) - else 0.0 - ) - z = ( - trc.topology.geometry[atom_idx * 3 + 2] - if atom_idx * 3 + 2 < len(trc.topology.geometry) - else 0.0 - ) - - formal_charge = 0 - if trc.topology.formal_charges and atom_idx < len( - trc.topology.formal_charges - ): - formal_charge = trc.topology.formal_charges[atom_idx].charge - - # Format ATOM record - record_type = ( - "ATOM" if AminoAcidSeq.is_amino_acid(residue_name) else "HETATM" - ) - - line = f"{record_type:<6}{atom_idx:>5} {atom_name:<4} {residue_name:>3} {chain_id}{seq_num:>4}{insertion_code:<1} {x:>8.3f}{y:>8.3f}{z:>8.3f} 1.00 0.00 {str(element):>2}{formal_charge:+2d}" - lines.append(line) - atom_idx += 1 - - lines.append("END") - return "\n".join(lines) diff 
--git a/src/rush/convert/sdf.py b/src/rush/convert/sdf.py deleted file mode 100644 index 61d4802..0000000 --- a/src/rush/convert/sdf.py +++ /dev/null @@ -1,515 +0,0 @@ -""" -SDF file parsing functionality. - -Converts SDF (Structure Data File) format to TRC structures. -Supports SDF V2000 format. -""" - -from enum import Enum -from typing import Any - -from ..mol import ( - TRC, - AtomRef, - Bond, - BondOrder, - Chain, - ChainRef, - Element, - FormalCharge, - Fragment, - Residue, - ResidueRef, -) - - -class SDFParseState(Enum): - """Parser state machine states.""" - - HEADER_BLOCK = "HeaderBlock" - COUNTS_LINE = "CountsLine" - ATOM_BLOCK = "AtomBlock" - BOND_BLOCK = "BondBlock" - PROPERTIES_BLOCK = "PropertiesBlock" - DATA_ITEMS = "DataItems" - DONE = "Done" - - -class SDFPropertyType(Enum): - """SDF property types.""" - - CHARGE = "CHG" - END = "END" - UNK = "Unk" - - -# SDF bond types: 1=single, 2=double, 3=triple, 4=aromatic/ring -_SDF_BOND_TYPES = [1, 2, 3, 4] - - -def _charge_field_to_charge(c: int) -> int | None: - """Convert SDF charge field to actual charge value.""" - charge_map = { - 0: 0, - 1: 3, - 2: 2, - 3: 1, - 5: -1, - 6: -2, - 7: -3, - } - return charge_map.get(c) - - -def _bond_order_from_sdf(order: int) -> BondOrder: - """Convert SDF bond type to BondOrder.""" - if order == 1: - return BondOrder.Single - elif order == 2: - return BondOrder.Double - elif order == 3: - return BondOrder.Triple - elif order == 4: - return BondOrder.Ring - else: - raise ValueError(f"Invalid bond type: {order}") - - -def _parse_sdf_entry(sdf_content: str) -> dict[str, Any]: - """ - Parse a single SDF entry into a molecule dictionary. - - SDF V2000 format: - - Line 1: Molecule name - - Line 2: User/Program name - - Line 3: Comment - - Line 4: Counts line (num_atoms num_bonds ...) - - Lines 5-4+num_atoms: Atom block (x y z symbol ...) 
- - Lines 5+num_atoms-4+num_atoms+num_bonds: Bond block - - Properties block (optional, e.g., CHG for charges) - - Data items (optional, e.g., SMILES) - - Terminator: "$$$$" - """ - state = SDFParseState.HEADER_BLOCK - seen_chg_property = False - num_atoms = 0 - num_bonds = 0 - - molecule = { - "name": "", - "atoms": [], - "bonds": [], - "associated_data": [], - } - - lines = sdf_content.split("\n") - line_number = 0 - i = 0 - - while i < len(lines): - line = lines[i] - line_number = i + 1 - - # Skip empty lines (except in header block) - if not line.strip() and state != SDFParseState.HEADER_BLOCK: - i += 1 - continue - - if state == SDFParseState.HEADER_BLOCK: - molecule["name"] = line.strip() - # Skip next two lines (user/program and comment) - if i + 2 >= len(lines): - raise ValueError(f"Line {line_number + 1}: Missing header lines") - i += 3 # Skip header + 2 comment lines - state = SDFParseState.COUNTS_LINE - continue - - elif state == SDFParseState.COUNTS_LINE: - if "V3000" in line: - raise ValueError(f"Line {line_number}: V3000 format not supported") - - if len(line) < 6: - raise ValueError(f"Line {line_number}: Counts line too short") - - try: - num_atoms = int(line[:3].strip()) - num_bonds = int(line[3:6].strip()) - except ValueError as e: - raise ValueError(f"Line {line_number}: Could not parse counts: {e}") - - molecule["atoms"] = [] - molecule["bonds"] = [] - - state = SDFParseState.ATOM_BLOCK - i += 1 - continue - - elif state == SDFParseState.ATOM_BLOCK: - if len(line) < 39: - raise ValueError(f"Line {line_number}: Atom line too short") - - try: - x = float(line[0:10].strip()) - y = float(line[10:20].strip()) - z = float(line[20:30].strip()) - symbol = line[30:33].strip() - - # Mass difference (optional, at position 33-35) - # TODO: never used - _mass_diff = 0 - if len(line) >= 35: - try: - _mass_diff = int(line[33:35].strip() or "0") - except ValueError: - pass - - # Charge (at position 36-39, but SDF uses special encoding) - charge = 0 - if 
len(line) >= 39: - try: - charge_field = int(line[36:39].strip() or "0") - charge = _charge_field_to_charge(charge_field) - if charge is None: - charge = 0 - except ValueError: - charge = 0 - - molecule["atoms"].append( - { - "x": x, - "y": y, - "z": z, - "symbol": symbol, - "charge": charge, - } - ) - - if len(molecule["atoms"]) >= num_atoms: - if num_bonds == 0: - state = SDFParseState.PROPERTIES_BLOCK - else: - state = SDFParseState.BOND_BLOCK - - except (ValueError, IndexError) as e: - raise ValueError(f"Line {line_number}: Could not parse atom: {e}") - - i += 1 - continue - - elif state == SDFParseState.BOND_BLOCK: - if len(line) < 9: - raise ValueError(f"Line {line_number}: Bond line too short") - - try: - atom1 = ( - int(line[0:3].strip()) - 1 - ) # SDF is 1-indexed, convert to 0-indexed - atom2 = int(line[3:6].strip()) - 1 - bond_type = int(line[6:9].strip()) - bond_stereo = 0 - if len(line) >= 12: - try: - bond_stereo = int(line[9:12].strip() or "0") - except ValueError: - pass - - if bond_type not in _SDF_BOND_TYPES: - raise ValueError( - f"Line {line_number}: Invalid bond type: {bond_type}" - ) - - molecule["bonds"].append( - { - "atom1": atom1, - "atom2": atom2, - "bond_type": bond_type, - "bond_stereo": bond_stereo, - } - ) - - if len(molecule["bonds"]) >= num_bonds: - state = SDFParseState.PROPERTIES_BLOCK - - except (ValueError, IndexError) as e: - raise ValueError(f"Line {line_number}: Could not parse bond: {e}") - - i += 1 - continue - - elif state == SDFParseState.PROPERTIES_BLOCK: - if len(line) < 6: - # Might be empty line or start of data items - if not line.strip(): - state = SDFParseState.DATA_ITEMS - i += 1 - continue - else: - state = SDFParseState.DATA_ITEMS - continue - - try: - prop_type_str = line[3:6].strip() - if prop_type_str == "CHG": - prop_type = SDFPropertyType.CHARGE - elif prop_type_str == "END": - prop_type = SDFPropertyType.END - else: - prop_type = SDFPropertyType.UNK - - if prop_type == SDFPropertyType.CHARGE: - if not 
seen_chg_property: - # Reset all charges to 0 - for atom in molecule["atoms"]: - atom["charge"] = 0 - seen_chg_property = True - - # Parse charge count (position 6-8 or 6-9) - # apparently, the count block is 6..9 in some standards, and 6..8 in others - # lets determine this by checking if 9 is a space or not - count_end = 9 if len(line) > 8 and line[8:9].strip() else 8 - if len(line) < count_end: - raise ValueError(f"Line {line_number}: CHG line too short") - - count = int(line[6:count_end].strip()) - - if count == 0 or count > 8: - raise ValueError( - f"Line {line_number}: CHG count out of range: {count}" - ) - - # Parse charge entries from the same line - # Each entry is 8 characters: 4 for index, 4 for charge - for j in range(count): - start = count_end + 8 * j - end = count_end + 4 + 8 * j - - if len(line) < end: - raise ValueError( - f"Line {line_number}: CHG entry out of range" - ) - - atom_idx = ( - int(line[start:end].strip()) - 1 - ) # 1-indexed to 0-indexed - - start = count_end + 4 + 8 * j - end = count_end + 8 + 8 * j - - if len(line) < end: - raise ValueError( - f"Line {line_number}: CHG entry out of range" - ) - - charge = int(line[start:end].strip()) - - if atom_idx < 0 or atom_idx >= len(molecule["atoms"]): - raise ValueError( - f"Line {line_number}: CHG atom index out of range: {atom_idx}" - ) - if charge < -3 or charge > 3: - raise ValueError( - f"Line {line_number}: CHG charge out of range: {charge}" - ) - - molecule["atoms"][atom_idx]["charge"] = charge - - elif prop_type == SDFPropertyType.END: - state = SDFParseState.DATA_ITEMS - - except (ValueError, IndexError): - # If we can't parse as property, assume we're in data items - state = SDFParseState.DATA_ITEMS - continue - - i += 1 - continue - - elif state == SDFParseState.DATA_ITEMS: - if line.strip() == "$$$$": - # Terminator found - break - - if line.startswith(">"): - # Data item key - start = line.find("<") - if start == -1: - raise ValueError(f"Line {line_number}: Invalid data item 
format") - end = line.find(">", start) - if end == -1: - raise ValueError(f"Line {line_number}: Invalid data item format") - - key = line[start + 1 : end] - data = [] - - # Read data until empty line - i += 1 - while i < len(lines): - data_line = lines[i] - if not data_line.strip(): - break - data.append(data_line) - i += 1 - - molecule["associated_data"].append((key, "\n".join(data))) - continue - - i += 1 - continue - - return molecule - - -def _sdf_entries(sdf_content: str) -> list[tuple[int, str]]: - """Split SDF content into individual entries (separated by $$$$).""" - entries = [] - tail = sdf_content - current_line_number = 1 - - while True: - terminator_pos = tail.find("\n$$$$") - if terminator_pos == -1: - # Check if there's a $$$$ at the end - if tail.strip().endswith("$$$$"): - entries.append((current_line_number, tail)) - else: - raise ValueError( - f"Line {current_line_number}: Missing SDF terminator ($$$$)" - ) - break - - offset_after_terminator = terminator_pos + 5 - if len(tail) == offset_after_terminator: - entries.append((current_line_number, tail)) - break - elif ( - len(tail) > offset_after_terminator - and tail[offset_after_terminator] != "\n" - ): - raise ValueError(f"Line {current_line_number}: Invalid terminator format") - else: - entry = tail[: terminator_pos + 6] # Include \n$$$$ - tail = tail[terminator_pos + 6 :] - entries.append((current_line_number, entry)) - current_line_number += entry.count("\n") - if not tail.strip(): - break - - return entries - - -def _molecule_to_trc(molecule: dict[str, Any]) -> TRC: - """ - Convert a parsed molecule to TRC structure. 
- - Creates a TRC with: - - Single residue containing all atoms - - Residue name from molecule name (or "LIG" if empty) - - Single chain containing that residue - - Bonds as connectivity - - Charges as formal_charges - """ - trc = TRC() - - num_atoms = len(molecule["atoms"]) - - # Build topology - symbols = [] - geometry = [] - formal_charges = [] - labels = [] - - # Track element counts for labeling (e.g., C1, C2, N1, H1, H2...) - element_counts = {} - - for atom in molecule["atoms"]: - try: - element = Element.from_str(atom["symbol"]) - except ValueError: - element = Element.C # Default to carbon if unknown - symbols.append(element) - geometry.extend([float(atom["x"]), float(atom["y"]), float(atom["z"])]) - formal_charges.append(FormalCharge(atom["charge"])) - - # Create label based on element symbol and sequence number for that element - element_symbol = atom["symbol"] - element_counts[element_symbol] = element_counts.get(element_symbol, 0) + 1 - labels.append(f"{element_symbol}{element_counts[element_symbol]}") - - trc.topology.symbols = symbols - trc.topology.geometry = geometry - trc.topology.formal_charges = formal_charges - trc.topology.labels = labels - - # Build connectivity (bonds) - ensure atom1 < atom2 for canonical ordering - bonds = [] - for bond in molecule["bonds"]: - atom1_idx = bond["atom1"] - atom2_idx = bond["atom2"] - bond_order = _bond_order_from_sdf(bond["bond_type"]) - - # Ensure atom1 < atom2 (canonical ordering) - if atom1_idx > atom2_idx: - atom1_idx, atom2_idx = atom2_idx, atom1_idx - - bonds.append( - Bond( - AtomRef(atom1_idx), - AtomRef(atom2_idx), - bond_order, - ) - ) - - trc.topology.connectivity = bonds - - # Fragments: single fragment with all atoms - trc.topology.fragments = [Fragment([AtomRef(i) for i in range(num_atoms)])] - - # Fragment formal charge: sum of all atom charges - fragment_formal_charge = sum(atom["charge"] for atom in molecule["atoms"]) - trc.topology.fragment_formal_charges = 
[FormalCharge(fragment_formal_charge)] - - # Build residues: single residue with all atoms - trc.residues.residues = [Residue([AtomRef(i) for i in range(num_atoms)])] - trc.residues.seqs = ["LIG"] - trc.residues.seq_ns = [0] - trc.residues.insertion_codes = [""] - # Label the ligand residue - trc.residues.labeled = [ResidueRef(0)] - trc.residues.labels = [[molecule["name"].strip().lower() or "LIG"]] - - # Build chains: single chain with residue 0 - trc.chains.chains = [Chain([ResidueRef(0)])] - trc.chains.labeled = [ChainRef(0)] - trc.chains.labels = [[molecule["name"].strip() or "LIG"]] - - return trc - - -def from_sdf(sdf_content: str) -> TRC | list[TRC]: - """ - Parse SDF file contents into TRC structures. - - Args: - sdf_content: SDF file content as string - - Returns: - TRC structure or list of TRC structures (one per molecule in the SDF file) - - Raises: - ValueError: If SDF parsing fails - """ - entries = _sdf_entries(sdf_content) - trcs: list[TRC] = [] - - for line_number, entry in entries: - try: - molecule = _parse_sdf_entry(entry) - trcs.append(_molecule_to_trc(molecule)) - except Exception as e: - raise ValueError( - f"Error parsing SDF entry starting at line {line_number}: {e}" - ) - - if len(trcs) == 1: - return trcs[0] - return trcs diff --git a/src/rush/exess/_energy.py b/src/rush/exess/_energy.py index d148d98..ae681fe 100644 --- a/src/rush/exess/_energy.py +++ b/src/rush/exess/_energy.py @@ -628,10 +628,7 @@ def __post_init__(self): def _to_rex(self): included_fragments = None if self.included_fragments: - included_fragments = [ - f.value if isinstance(f, FragmentRef) else f - for f in self.included_fragments - ] + included_fragments = list(self.included_fragments) return Template( """Some (exess_rex::FragKeywords { cutoffs = Some (exess_rex::FragmentCutoffs { diff --git a/src/rush/exess/_optimization.py b/src/rush/exess/_optimization.py index 502ce44..dd930f0 100644 --- a/src/rush/exess/_optimization.py +++ b/src/rush/exess/_optimization.py @@ 
-275,7 +275,7 @@ def fetch(self) -> OptimizationResult: trajectory_raw = self.trajectory.fetch_list() steps_raw = self.steps.fetch_list() - trajectory = [Topology.from_json(t) for t in trajectory_raw] + trajectory = [Topology.from_dict(t) for t in trajectory_raw] steps = [OptimizationStep(**step) for step in steps_raw] return OptimizationResult(trajectory=trajectory, steps=steps) diff --git a/src/rush/mol.py b/src/rush/mol.py index 92f2720..b58e194 100644 --- a/src/rush/mol.py +++ b/src/rush/mol.py @@ -1,706 +1,128 @@ """ -Provides data structures and helpers for molecular systems and structures: +Molecular structure types for Rush. -- Classes for Rush Topology, Residues, Chains, and TRC types. -- Element types and bonds. -- Fragment type to represent fragmented systems. +Core types are provided by the native ``libqdx`` extension and re-exported +here for convenience. This module used to contain pure-Python dataclass +implementations; those have been replaced by opaque Rust-backed objects +from ``libqdx`` for performance and correctness. -Quick Links ------------ +Primary types +------------- +TRC + Combined **Topology + Residues + Chains** structure -- the main + representation for molecular systems on the Rush platform. Construct + via ``TRC.from_dict(d)`` or by loading a file through + ``rush.convert.load_structure``. 
-- :class:`rush.mol.TRC` -- :class:`rush.mol.Topology` -- :class:`rush.mol.Residues` -- :class:`rush.mol.Chains` -""" - -import json -import sys -from collections import defaultdict -from dataclasses import dataclass, field -from enum import Enum, IntEnum -from functools import total_ordering -from pathlib import Path -from typing import Self - - -class Element(IntEnum): - """Represents all relevant elements.""" - - X = 0 - H = 1 - He = 2 - Li = 3 - Be = 4 - B = 5 - C = 6 - N = 7 - O = 8 # noqa: E741 - F = 9 - Ne = 10 - Na = 11 - Mg = 12 - Al = 13 - Si = 14 - P = 15 - S = 16 - Cl = 17 - Ar = 18 - K = 19 - Ca = 20 - Sc = 21 - Ti = 22 - V = 23 - Cr = 24 - Mn = 25 - Fe = 26 - Co = 27 - Ni = 28 - Cu = 29 - Zn = 30 - Ga = 31 - Ge = 32 - As = 33 - Se = 34 - Br = 35 - Kr = 36 - - @classmethod - def from_str(cls, symbol: str) -> "Element": - """Parse element from string symbol.""" - # First try the symbol as-is (for proper case like "Fe") - try: - return cls[symbol] - except KeyError: - pass - - # Try uppercase (for "FE" -> "Fe") - symbol_upper = symbol.upper() - try: - # Check all enum members for case-insensitive match - for member in cls: - if member.name.upper() == symbol_upper: - return member - except Exception: - pass - - # Try common variations - if symbol_upper in ["D"]: # Deuterium -> Hydrogen - return cls.H - - raise ValueError(f"Unknown element symbol: {symbol}") - - def __str__(self) -> str: - return self.name - - -@total_ordering -class AtomRef: - """Reference to an atom by index.""" - - def __init__(self, value: int): - if value < 0: - raise ValueError("Atom index must be non-negative") - self.value = value - - def __eq__(self, other): - return isinstance(other, AtomRef) and self.value == other.value - - def __lt__(self, other): - return isinstance(other, AtomRef) and self.value < other.value - - def __hash__(self): - return hash(self.value) - - def __repr__(self): - return f"AtomRef({self.value})" - - def __int__(self): - return self.value - - 
-@total_ordering -class FragmentRef: - """Reference to a fragment by index.""" - - def __init__(self, value: int): - if value < 0: - raise ValueError("Fragment index must be non-negative") - self.value = value - - def __eq__(self, other): - return isinstance(other, FragmentRef) and self.value == other.value - - def __lt__(self, other): - return isinstance(other, FragmentRef) and self.value < other.value - - def __hash__(self): - return hash(self.value) - - def __repr__(self): - return f"FragmentRef({self.value})" - - def __int__(self): - return self.value - - -@total_ordering -class ResidueRef: - """Reference to a residue by index.""" - - def __init__(self, value: int): - if value < 0: - raise ValueError("Residue index must be non-negative") - self.value = value - - def __eq__(self, other): - return isinstance(other, ResidueRef) and self.value == other.value - - def __lt__(self, other): - return isinstance(other, ResidueRef) and self.value < other.value - - def __hash__(self): - return hash(self.value) - - def __repr__(self): - return f"ResidueRef({self.value})" - - def __int__(self): - return self.value - - -@total_ordering -class ChainRef: - """Reference to a chain by index.""" - - def __init__(self, value: int): - if value < 0: - raise ValueError("Chain index must be non-negative") - self.value = value - - def __eq__(self, other): - return isinstance(other, ChainRef) and self.value == other.value - - def __lt__(self, other): - return isinstance(other, ChainRef) and self.value < other.value - - def __hash__(self): - return hash(self.value) - - def __repr__(self): - return f"ChainRef({self.value})" - - def __int__(self): - return self.value - - -@dataclass -class FormalCharge: - """Formal charge of an atom.""" - - charge: int - - def __repr__(self): - return f"FormalCharge({self.charge})" - - def __int__(self): - return self.charge - - -@dataclass -class PartialCharge: - """Partial charge of an atom.""" - - charge: float - - def __repr__(self): - return 
f"PartialCharge({self.charge})" - - def __float__(self): - return self.charge - - -class BondOrder(IntEnum): - """Bond order enum.""" - - Single = 1 - Double = 2 - Triple = 3 - OneAndAHalf = 4 # Partial bond (e.g. amide bond) - Ring = 5 # Aromatic - - -@dataclass -class Bond: - """Bond between two atoms.""" - - atom1: AtomRef - atom2: AtomRef - order: BondOrder - - def __post_init__(self): - if self.atom1.value == self.atom2.value: - raise ValueError("Bond cannot connect an atom to itself") - - -class Fragment: - """Fragment containing a list of atoms.""" - - def __init__(self, atoms: list[AtomRef] | list[int] | None = None): - # Store as list of integers to match JSON serialization - if atoms is None: - self.atoms = [] - else: - self.atoms = [ - atom.value if isinstance(atom, AtomRef) else atom for atom in atoms - ] - - def __len__(self) -> int: - return len(self.atoms) - - def __iter__(self): - # Return AtomRef objects when iterating - return (AtomRef(atom) for atom in self.atoms) - - def __eq__(self, other): - return isinstance(other, Fragment) and self.atoms == other.atoms - - def __repr__(self): - return f"Fragment({[AtomRef(a) for a in self.atoms]})" - - -class SchemaVersion(Enum): - """Schema version for the topology format.""" - - V1 = "v1" - V2 = "v2" - - -@dataclass -class Topology: - """Topology contains all atom information.""" - - schema_version: SchemaVersion = SchemaVersion.V2 - - # Element of each atom - symbols: list[Element] = field(default_factory=list) - - # XYZ coordinates of each atom (3 * len(symbols)) - geometry: list[float] = field(default_factory=list) - - # Optional atom labels - labels: list[str] | None = None - - # Optional formal charges - formal_charges: list[FormalCharge] | None = None - - # Optional partial charges - partial_charges: list[PartialCharge] | None = None - - # Optional connectivity - connectivity: list[Bond] | None = None - - # Optional velocities (3 * len(symbols)) - velocities: list[float] | None = None - - # Optional 
fragments - fragments: list[Fragment] | None = None - - # Optional fragment charges - fragment_formal_charges: list[FormalCharge] | None = None - fragment_partial_charges: list[PartialCharge] | None = None - - @staticmethod - def from_json(json_content: str | Path | dict) -> "Topology": - if isinstance(json_content, str): - topology_data = json.loads(json_content) - elif isinstance(json_content, Path): - with json_content.open() as f: - topology_data = json.load(f) - elif isinstance(json_content, dict): - topology_data = json_content - else: - print( - "WARNING: Tried to load Topology from JSON but " - "it wasn't a str, Path, or dict!", - file=sys.stderr, - ) - topology_data = json_content - - topology = Topology() - - # Default, could parse from schema_version - topology.schema_version = SchemaVersion.V2 - - topology.symbols = [Element.from_str(s) for s in topology_data["symbols"]] - topology.geometry = topology_data["geometry"] - - if "labels" in topology_data and topology_data["labels"]: - topology.labels = topology_data["labels"] - - if "formal_charges" in topology_data and topology_data["formal_charges"]: - topology.formal_charges = [ - FormalCharge(c) for c in topology_data["formal_charges"] - ] - - if "partial_charges" in topology_data and topology_data["partial_charges"]: - topology.partial_charges = [ - PartialCharge(c) for c in topology_data["partial_charges"] - ] - - if "velocities" in topology_data and topology_data["velocities"]: - topology.velocities = topology_data["velocities"] - - if "connectivity" in topology_data and topology_data["connectivity"]: - # Connectivity is a list of [atom1, atom2, bond_order] - # BondOrder enum: 1=Single, 2=Double, 3=Triple, 4=OneAndAHalf (partial), 5=Ring (aromatic) - bonds = [] - for bond_data in topology_data["connectivity"]: - if isinstance(bond_data, list) and len(bond_data) >= 2: - atom1_idx = bond_data[0] - atom2_idx = bond_data[1] - bond_order_val = bond_data[2] - - # Support old version mapping: 254 -> 4 
(OneAndAHalf/partial), 255 -> 5 (Ring/aromatic) - if bond_order_val == 254: - bond_order_val = 4 - elif bond_order_val == 255: - bond_order_val = 5 - - bond_order = BondOrder(bond_order_val) - bonds.append( - Bond(AtomRef(atom1_idx), AtomRef(atom2_idx), bond_order) - ) - topology.connectivity = bonds - - if "fragments" in topology_data and topology_data["fragments"]: - topology.fragments = [Fragment(frag) for frag in topology_data["fragments"]] - - if ( - "fragment_formal_charges" in topology_data - and topology_data["fragment_formal_charges"] - ): - topology.fragment_formal_charges = [ - FormalCharge(c) for c in topology_data["fragment_formal_charges"] - ] - - if ( - "fragment_partial_charges" in topology_data - and topology_data["fragment_partial_charges"] - ): - topology.fragment_partial_charges = [ - PartialCharge(c) for c in topology_data["fragment_partial_charges"] - ] - - return topology - - def to_dict(self) -> dict[str, object]: - topology_dict: dict[str, object] = { - "schema_version": "0.2.0", - "symbols": [str(symbol) for symbol in self.symbols], - "geometry": self.geometry, - } - - if self.labels is not None: - topology_dict["labels"] = self.labels - - if self.formal_charges is not None: - topology_dict["formal_charges"] = [c.charge for c in self.formal_charges] - - if self.partial_charges is not None: - topology_dict["partial_charges"] = [c.charge for c in self.partial_charges] - - if self.connectivity is not None: - topology_dict["connectivity"] = [ - [bond.atom1.value, bond.atom2.value, bond.order.value] - for bond in self.connectivity - ] - else: - topology_dict["connectivity"] = [] - - if self.velocities is not None: - topology_dict["velocities"] = self.velocities - - if self.fragments is not None: - topology_dict["fragments"] = [fragment.atoms for fragment in self.fragments] - else: - topology_dict["fragments"] = [] - - if self.fragment_formal_charges is not None: - topology_dict["fragment_formal_charges"] = [ - c.charge for c in 
self.fragment_formal_charges - ] - else: - topology_dict["fragment_formal_charges"] = [] - - if self.fragment_partial_charges is not None: - topology_dict["fragment_partial_charges"] = [ - c.charge for c in self.fragment_partial_charges - ] - - return topology_dict - - def check(self) -> None: - """Validate the topology structure.""" - # Check geometry length - if len(self.geometry) != len(self.symbols) * 3: - raise ValueError( - f"Geometry length {len(self.geometry)} != symbols length {len(self.symbols)} * 3" - ) - - # Check optional field lengths - if self.labels is not None and len(self.labels) != len(self.symbols): - raise ValueError( - f"Labels length {len(self.labels)} != symbols length {len(self.symbols)}" - ) - - if self.partial_charges is not None and len(self.partial_charges) != len( - self.symbols - ): - raise ValueError( - f"Partial charges length {len(self.partial_charges)} != symbols length {len(self.symbols)}" - ) - - if self.formal_charges is not None and len(self.formal_charges) != len( - self.symbols - ): - raise ValueError( - f"Formal charges length {len(self.formal_charges)} != symbols length {len(self.symbols)}" - ) - - if ( - self.velocities is not None - and len(self.velocities) != len(self.symbols) * 3 - ): - raise ValueError( - f"Velocities length {len(self.velocities)} != symbols length {len(self.symbols)} * 3" - ) - - # Check connectivity - if self.connectivity is not None: - for bond in self.connectivity: - if bond.atom1.value >= len(self.symbols) or bond.atom2.value >= len( - self.symbols - ): - raise ValueError( - f"Bond references invalid atom indices: {bond.atom1.value}, {bond.atom2.value}" - ) - - # Check fragments - if self.fragments is not None: - atom_set = set() - for fragment in self.fragments: - for atom_idx in fragment.atoms: - if atom_idx >= len(self.symbols): - raise ValueError( - f"Fragment references invalid atom index: {atom_idx}" - ) - if atom_idx in atom_set: - raise ValueError( - f"Atom {atom_idx} appears in multiple 
fragments" - ) - atom_set.add(atom_idx) +Topology + Per-atom information: element symbols, XYZ geometry (flat list, + 3 * n_atoms), optional atom labels, formal/partial charges, bond + connectivity, velocities, and fragment assignments. - if len(atom_set) != len(self.symbols): - raise ValueError("Not all atoms are assigned to fragments") +Residues + Residue groupings over atoms -- sequence names (e.g. amino-acid + three-letter codes), sequence numbers, insertion codes, and the + mapping of which atoms belong to which residue. - def distance_between_atoms(self, atom1: AtomRef, atom2: AtomRef) -> float: - """Calculate distance between two atoms.""" - if atom1.value >= len(self.symbols) or atom2.value >= len(self.symbols): - raise ValueError("Invalid atom indices") +Chains + Chain groupings over residues, plus optional secondary-structure + annotations (alpha helices and beta sheets). - i1, i2 = atom1.value * 3, atom2.value * 3 - dx = self.geometry[i1] - self.geometry[i2] - dy = self.geometry[i1 + 1] - self.geometry[i2 + 1] - dz = self.geometry[i1 + 2] - self.geometry[i2 + 2] +Element & bond types +-------------------- +Element + Chemical element enum (H, He, Li, ..., Kr). Integer-valued, + matching atomic number. - return (dx * dx + dy * dy + dz * dz) ** 0.5 +Bond + A bond between two atoms (atom indices + bond order). - def distance_to_point( - self, atom: AtomRef, point: tuple[float, float, float] - ) -> float: - """Calculate distance from atom to a point.""" - if atom.value >= len(self.symbols): - raise ValueError("Invalid atom index") +BondOrder + Bond order enum: Single, Double, Triple, OneAndAHalf (partial / + amide), Ring (aromatic). - i = atom.value * 3 - dx = self.geometry[i] - point[0] - dy = self.geometry[i + 1] - point[1] - dz = self.geometry[i + 2] - point[2] +Stereochemistry + Atom stereochemistry descriptor (R/S chirality, E/Z geometry, etc.). 
- return (dx * dx + dy * dy + dz * dz) ** 0.5 +Secondary structure +------------------- +HelixClass + PDB helix classification (right-handed alpha, 3-10, pi, etc.). - def get_atoms_near_point( - self, - point: tuple[float, float, float], - threshold: float, - atom_indices: list[int] | None = None, - ) -> list[int]: - """Get atom indices within threshold distance of a point.""" - if atom_indices is None: - atom_indices = list(range(len(self.symbols))) +StrandSense + Parallel vs. anti-parallel strand orientation in a beta sheet. - near_atoms = [] - for atom_idx in atom_indices: - if atom_idx >= len(self.symbols): - continue +AlphaHelices + Collection of alpha-helix annotations for a structure. - distance = self.distance_to_point(AtomRef(atom_idx), point) - if distance <= threshold: - near_atoms.append(atom_idx) +BetaSheets + Collection of beta-sheet annotations for a structure. - return near_atoms +Reference / index types +----------------------- +AtomRef + ``NewType`` over ``int`` -- a zero-based atom index. - def get_fragments_near_fragment( - self, - frag_idx: int, - threshold: float, - atom_indices: list[int] | None = None, - ) -> list[FragmentRef]: - """Get fragment indices within threshold distance of another fragment.""" - if not self.fragments: - return [] +ResidueRef + ``NewType`` over ``int`` -- a zero-based residue index. - if atom_indices is None: - atom_indices = list(range(len(self.symbols))) +ChainRef + ``NewType`` over ``int`` -- a zero-based chain index. - near_atoms = set() - for atom_idx in self.fragments[frag_idx]: - atom_idx = int(atom_idx) - if atom_idx >= len(self.symbols): - print("Warning: bad atom index {atom_index}", file=sys.stderr) - continue +FragmentRef + ``NewType`` over ``int`` -- a zero-based fragment index. 
- near_atoms |= { - AtomRef(a) - for a in self.get_atoms_near_point( - ( - self.geometry[atom_idx * 3], - self.geometry[atom_idx * 3 + 1], - self.geometry[atom_idx * 3 + 2], - ), - threshold, - ) - } +These are plain ``int`` at runtime but provide static-analysis +distinctness so that, e.g., an ``AtomRef`` is not accidentally used +where a ``ResidueRef`` is expected. - return [ - FragmentRef(i) - for (i, f) in enumerate(self.fragments) - if (i != frag_idx and not near_atoms.isdisjoint(f)) - ] +Quick examples +-------------- +Loading a structure and inspecting it:: - def extend(self, other: Self) -> None: - """Extend this topology with atoms from another topology.""" - offset = len(self.symbols) + from rush.convert import load_structure - # Extend basic arrays - self.symbols.extend(other.symbols) - self.geometry.extend(other.geometry) + trc = load_structure("1crn.pdb") + print(len(trc.topology.symbols)) # number of atoms + print(trc.residues.seqs[:5]) # first five residue names + print(len(trc.chains.chains)) # number of chains - # Extend optional arrays - if self.labels is not None and other.labels is not None: - self.labels.extend(other.labels) - elif self.labels is not None and other.labels is None: - self.labels.extend([""] * len(other.symbols)) +Converting to/from JSON dicts:: - if self.partial_charges is not None and other.partial_charges is not None: - self.partial_charges.extend(other.partial_charges) - elif self.partial_charges is not None and other.partial_charges is None: - self.partial_charges.extend([PartialCharge(0.0)] * len(other.symbols)) + d = trc.to_dict() # -> dict with topology/residues/chains + trc2 = TRC.from_dict(d) # round-trip back to TRC - if self.formal_charges is not None and other.formal_charges is not None: - self.formal_charges.extend(other.formal_charges) - elif self.formal_charges is not None and other.formal_charges is None: - self.formal_charges.extend([FormalCharge(0)] * len(other.symbols)) +Validation:: - if self.velocities is 
not None and other.velocities is not None: - self.velocities.extend(other.velocities) - elif self.velocities is not None and other.velocities is None: - self.velocities.extend([0.0] * (len(other.symbols) * 3)) - - # Update connectivity with offset - if other.connectivity is not None: - if self.connectivity is None: - self.connectivity = [] - for bond in other.connectivity: - new_bond = Bond( - AtomRef(bond.atom1.value + offset), - AtomRef(bond.atom2.value + offset), - bond.order, - ) - self.connectivity.append(new_bond) - - # Update fragments with offset - if self.fragments is not None and other.fragments is not None: - for fragment in other.fragments: - new_atoms = [AtomRef(atom + offset) for atom in fragment.atoms] - self.fragments.append(Fragment(new_atoms)) - elif self.fragments is not None and other.fragments is None: - # Create a single fragment for all new atoms - new_atoms = [AtomRef(i + offset) for i in range(len(other.symbols))] - self.fragments.append(Fragment(new_atoms)) - - # Extend fragment charges - if other.fragment_formal_charges is not None: - if self.fragment_formal_charges is None: - self.fragment_formal_charges = [] - self.fragment_formal_charges.extend(other.fragment_formal_charges) - - if other.fragment_partial_charges is not None: - if self.fragment_partial_charges is None: - self.fragment_partial_charges = [] - self.fragment_partial_charges.extend(other.fragment_partial_charges) - - def new_topology_from_residue_subset( - self, residue_subset: list["Residue"] - ) -> "Topology": - """Create a new topology containing only atoms from specified residues.""" - new_topology = Topology(schema_version=self.schema_version) - - # Collect all atom indices from residues - atom_indices = [] - for residue in residue_subset: - atom_indices.extend(residue.atoms) # Already integers - - # Build atom mapping - old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(atom_indices)} - - # Copy basic data - new_topology.symbols = [self.symbols[i] for i 
in atom_indices] - new_topology.geometry = [] - for i in atom_indices: - new_topology.geometry.extend(self.geometry[i * 3 : (i + 1) * 3]) - - # Copy optional data - if self.labels: - new_topology.labels = [self.labels[i] for i in atom_indices] - - if self.partial_charges: - new_topology.partial_charges = [ - self.partial_charges[i] for i in atom_indices - ] - - if self.formal_charges: - new_topology.formal_charges = [self.formal_charges[i] for i in atom_indices] - - if self.velocities: - new_topology.velocities = [] - for i in atom_indices: - new_topology.velocities.extend(self.velocities[i * 3 : (i + 1) * 3]) - - # Copy connectivity (only bonds between atoms in subset) - if self.connectivity: - new_topology.connectivity = [] - for bond in self.connectivity: - if bond.atom1.value in old_to_new and bond.atom2.value in old_to_new: - new_bond = Bond( - AtomRef(old_to_new[bond.atom1.value]), - AtomRef(old_to_new[bond.atom2.value]), - bond.order, - ) - new_topology.connectivity.append(new_bond) + trc.check() # raises on inconsistent data +""" - return new_topology +import sys +from enum import Enum +from typing import NewType + +import libqdx + +# Ref types — distinct for type checking, plain int at runtime +AtomRef = NewType("AtomRef", int) +ResidueRef = NewType("ResidueRef", int) +ChainRef = NewType("ChainRef", int) +FragmentRef = NewType("FragmentRef", int) + +# Re-export native types +TRC = libqdx.PyTRC +Topology = libqdx.PyTopology +Residues = libqdx.PyResidues +Chains = libqdx.PyChains +Element = libqdx.Element +Bond = libqdx.Bond +BondOrder = libqdx.BondOrder +Stereochemistry = libqdx.Stereochemistry +HelixClass = libqdx.HelixClass +StrandSense = libqdx.StrandSense +AlphaHelices = libqdx.AlphaHelices +BetaSheets = libqdx.BetaSheets +AtomCheckStrictness = libqdx.AtomCheckStrictness class AminoAcidSeq(Enum): @@ -754,494 +176,131 @@ def is_amino_acid(cls, residue_name: str) -> bool: except ValueError: return False + _SINGLE_LETTER = { + "GLY": "G", + "ALA": "A", + 
"VAL": "V", + "LEU": "L", + "ILE": "I", + "PRO": "P", + "SER": "S", + "THR": "T", + "ASN": "N", + "GLN": "Q", + "CYS": "C", + "CYD": "C", + "CYX": "C", + "MET": "M", + "PHE": "F", + "TYR": "Y", + "TYD": "Y", + "TRP": "W", + "ASP": "D", + "ASH": "D", + "GLU": "E", + "GLH": "E", + "HIS": "H", + "HIN": "H", + "HID": "H", + "HIE": "H", + "HIP": "H", + "LYS": "K", + "LYD": "K", + "LYN": "K", + "ARG": "R", + "HYP": "O", + } + def to_single_letter(self) -> str: """Convert to single letter code.""" - mapping = { - "GLY": "G", - "ALA": "A", - "VAL": "V", - "LEU": "L", - "ILE": "I", - "PRO": "P", - "SER": "S", - "THR": "T", - "ASN": "N", - "GLN": "Q", - "CYS": "C", - "CYD": "C", - "CYX": "C", - "MET": "M", - "PHE": "F", - "TYR": "Y", - "TYD": "Y", - "TRP": "W", - "ASP": "D", - "ASH": "D", - "GLU": "E", - "GLH": "E", - "HIS": "H", - "HIN": "H", - "HID": "H", - "HIE": "H", - "HIP": "H", - "LYS": "K", - "LYD": "K", - "LYN": "K", - "ARG": "R", - "HYP": "O", - } - return mapping.get(self.value, "X") - - -class Residue: - """A residue containing a list of atoms.""" - - def __init__(self, atoms: list[AtomRef] | list[int] | None = None): - # Store as list of integers to match JSON serialization - if atoms is None: - self.atoms = [] - else: - self.atoms = [ - atom.value if isinstance(atom, AtomRef) else atom for atom in atoms - ] - - def __len__(self) -> int: - return len(self.atoms) - - def __iter__(self): - # Return AtomRef objects when iterating - return (AtomRef(atom) for atom in self.atoms) - - def contains(self, atom: AtomRef) -> bool: - return atom.value in self.atoms - - def __eq__(self, other): - return isinstance(other, Residue) and self.atoms == other.atoms - - def __repr__(self): - return f"Residue({[AtomRef(a) for a in self.atoms]})" - - -@dataclass -class Residues: - """Collection of residues with metadata.""" - - # List of residues - residues: list[Residue] = field(default_factory=list) - - # Sequence names (e.g., amino acid names) - seqs: list[str] = 
field(default_factory=list) - - # Sequence numbers - seq_ns: list[int] = field(default_factory=list) - - # Insertion codes - insertion_codes: list[str] = field(default_factory=list) - - # WARN: Deprecated - labeled: list[ResidueRef] | None = None - - # WARN: Deprecated - labels: list[list[str]] | None = None - - @staticmethod - def from_json(json_content: str | Path | dict) -> "Residues": - if isinstance(json_content, str): - residues_data = json.loads(json_content) - elif isinstance(json_content, Path): - with json_content.open() as f: - residues_data = json.load(f) - elif isinstance(json_content, dict): - residues_data = json_content - else: - print( - "WARNING: Tried to load Residues from JSON but " - "it wasn't a str, Path, or dict!", - file=sys.stderr, - ) - residues_data = json_content - - residues = Residues() - residues.residues = [Residue(res) for res in residues_data["residues"]] - residues.seqs = residues_data["seqs"] - residues.seq_ns = residues_data["seq_ns"] - residues.insertion_codes = residues_data["insertion_codes"] - - if residues_data.get("labeled"): - residues.labeled = [ResidueRef(r) for r in residues_data["labeled"]] - - if residues_data.get("labels"): - residues.labels = residues_data["labels"] - - return residues - - def to_dict(self) -> dict[str, object]: - residues_dict: dict[str, object] = { - "residues": [residue.atoms for residue in self.residues], - "seqs": self.seqs, - "seq_ns": self.seq_ns, - "insertion_codes": self.insertion_codes, - } - - if self.labeled is not None: - residues_dict["labeled"] = [r.value for r in self.labeled] - - if self.labels is not None: - residues_dict["labels"] = self.labels - - return residues_dict - - def check(self) -> None: - """Validate the residues structure.""" - if len(self.seqs) != len(self.residues): - raise ValueError( - f"Seqs length {len(self.seqs)} != residues length {len(self.residues)}" - ) - - if len(self.seq_ns) != len(self.residues): - raise ValueError( - f"Seq_ns length {len(self.seq_ns)} 
!= residues length {len(self.residues)}" - ) - - if len(self.insertion_codes) != len(self.residues): - raise ValueError( - f"Insertion codes length {len(self.insertion_codes)} != residues length {len(self.residues)}" - ) - - def is_amino_acid(self, index: int) -> bool: - """Check if residue at index is an amino acid.""" - if index >= len(self.seqs): - return False - return AminoAcidSeq.is_amino_acid(self.seqs[index]) - - def amino_acid_indices(self) -> list[int]: - """Get indices of amino acid residues.""" - return [i for i in range(len(self.seqs)) if self.is_amino_acid(i)] - - def non_amino_acid_indices(self) -> list[int]: - """Get indices of non-amino acid residues.""" - return [i for i in range(len(self.seqs)) if not self.is_amino_acid(i)] - - def extend(self, other: Self) -> None: - """Extend this residues collection with another.""" - # Calculate atom offset for renumbering - offset = sum(len(residue.atoms) for residue in self.residues) - - # Calculate residue offset before extending (number of residues in self before merge) - residue_offset = len(self.residues) - - # Extend residues with renumbered atoms - for residue in other.residues: - new_atoms = [atom + offset for atom in residue.atoms] - self.residues.append(Residue(new_atoms)) - - # Extend metadata - self.seqs.extend(other.seqs) - self.seq_ns.extend(other.seq_ns) - self.insertion_codes.extend(other.insertion_codes) - - # Handle labeled residues and labels - if other.labeled is not None: - if self.labeled is None: - self.labeled = [] - # Renumber residue references - for ref in other.labeled: - if isinstance(ref, ResidueRef): - self.labeled.append(ResidueRef(ref.value + residue_offset)) - elif isinstance(ref, int): - self.labeled.append(ref + residue_offset) - else: - self.labeled.append(ref) - - if other.labels is not None: - if self.labels is None: - self.labels = [] - # Copy labels (they're lists, so we need to copy them) - for label in other.labels: - if isinstance(label, list): - 
self.labels.append(label.copy()) - else: - self.labels.append(label) - - def new_residues_from_subset(self, residue_refs: list[ResidueRef]) -> "Residues": - """Create new residues collection from a subset of residue references.""" - new_residues = Residues() - - offset = 0 - for residue_ref in residue_refs: - if residue_ref.value >= len(self.residues): - continue - - # Get original residue - original_residue = self.residues[residue_ref.value] - residue_len = len(original_residue.atoms) - - # Create new residue with renumbered atoms - new_atoms = [offset + i for i in range(residue_len)] - new_residues.residues.append(Residue(new_atoms)) - - # Copy metadata - new_residues.seqs.append(self.seqs[residue_ref.value]) - new_residues.seq_ns.append(self.seq_ns[residue_ref.value]) - new_residues.insertion_codes.append(self.insertion_codes[residue_ref.value]) - - offset += residue_len - - return new_residues - - -class Chain: - """A chain containing a list of residues.""" - - def __init__(self, residues: list[ResidueRef] | list[int] | None = None): - # Store as list of integers to match JSON serialization - if residues is None: - self.residues = [] - else: - self.residues = [ - res.value if isinstance(res, ResidueRef) else res for res in residues - ] - - def __len__(self) -> int: - return len(self.residues) - - def __iter__(self): - # Return ResidueRef objects when iterating - return (ResidueRef(res) for res in self.residues) - - def contains(self, residue: ResidueRef) -> bool: - return residue.value in self.residues - - def __eq__(self, other): - return isinstance(other, Chain) and self.residues == other.residues - - def __repr__(self): - return f"Chain({[ResidueRef(r) for r in self.residues]})" - - -@dataclass -class Chains: - """Collection of chains with secondary structure information.""" - - # List of chains - chains: list[Chain] = field(default_factory=list) - - # Optional alpha helix residues - alpha_helices: list[ResidueRef] | None = None - - # Optional beta sheet 
residues - beta_sheets: list[ResidueRef] | None = None - - # WARN: Deprecated - labeled: list[ChainRef] | None = None - - # WARN: Deprecated - labels: list[list[str]] | None = None - - @staticmethod - def from_json(json_content: str | Path | dict) -> "Chains": - if isinstance(json_content, str): - chains_data = json.loads(json_content) - elif isinstance(json_content, Path): - with json_content.open() as f: - chains_data = json.load(f) - elif isinstance(json_content, dict): - chains_data = json_content - else: - print( - "WARNING: Tried to load Chains from JSON but " - "it wasn't a str, Path, or dict!", - file=sys.stderr, + return self._SINGLE_LETTER.get(self.value, "X") + + +def distance_between_atoms(topology: Topology, atom1: AtomRef, atom2: AtomRef) -> float: + """Calculate distance between two atoms.""" + if atom1 >= len(topology.symbols) or atom2 >= len(topology.symbols): + raise ValueError("Invalid atom indices") + + i1, i2 = atom1 * 3, atom2 * 3 + dx = topology.geometry[i1] - topology.geometry[i2] + dy = topology.geometry[i1 + 1] - topology.geometry[i2 + 1] + dz = topology.geometry[i1 + 2] - topology.geometry[i2 + 2] + + return (dx * dx + dy * dy + dz * dz) ** 0.5 + + +def distance_to_point( + topology: Topology, atom: AtomRef, point: tuple[float, float, float] +) -> float: + """Calculate distance from atom to a point.""" + if atom >= len(topology.symbols): + raise ValueError("Invalid atom index") + + i = atom * 3 + dx = topology.geometry[i] - point[0] + dy = topology.geometry[i + 1] - point[1] + dz = topology.geometry[i + 2] - point[2] + + return (dx * dx + dy * dy + dz * dz) ** 0.5 + + +def get_atoms_near_point( + topology: Topology, + point: tuple[float, float, float], + threshold: float, + atom_indices: list[int] | None = None, +) -> list[int]: + """Get atom indices within threshold distance of a point.""" + if atom_indices is None: + atom_indices = list(range(len(topology.symbols))) + + near_atoms = [] + for atom_idx in atom_indices: + if atom_idx >= 
len(topology.symbols): + continue + + distance = distance_to_point(topology, AtomRef(atom_idx), point) + if distance <= threshold: + near_atoms.append(atom_idx) + + return near_atoms + + +def get_fragments_near_fragment( + topology: Topology, + frag_idx: int, + threshold: float, + atom_indices: list[int] | None = None, +) -> list[FragmentRef]: + """Get fragment indices within threshold distance of another fragment.""" + if not topology.fragments: + return [] + + if atom_indices is None: + atom_indices = list(range(len(topology.symbols))) + + near_atoms = set() + for atom_idx in topology.fragments[frag_idx]: + atom_idx = int(atom_idx) + if atom_idx >= len(topology.symbols): + print("Warning: bad atom index {atom_index}", file=sys.stderr) + continue + + near_atoms |= { + AtomRef(a) + for a in get_atoms_near_point( + topology, + ( + topology.geometry[atom_idx * 3], + topology.geometry[atom_idx * 3 + 1], + topology.geometry[atom_idx * 3 + 2], + ), + threshold, ) - chains_data = json_content - - chains = Chains() - - chains.chains = [Chain(chain) for chain in chains_data["chains"]] - - if chains_data.get("alpha_helices"): - chains.alpha_helices = [ResidueRef(r) for r in chains_data["alpha_helices"]] - - if chains_data.get("beta_sheets"): - chains.beta_sheets = [ResidueRef(r) for r in chains_data["beta_sheets"]] - - if chains_data.get("labeled"): - chains.labeled = [ChainRef(c) for c in chains_data["labeled"]] - - if chains_data.get("labels"): - chains.labels = chains_data["labels"] - - return chains - - def to_dict(self) -> dict[str, object]: - chains_dict: dict[str, object] = { - "chains": [chain.residues for chain in self.chains], } - if self.alpha_helices is not None: - chains_dict["alpha_helices"] = [r.value for r in self.alpha_helices] - - if self.beta_sheets is not None: - chains_dict["beta_sheets"] = [r.value for r in self.beta_sheets] - - if self.labeled is not None: - chains_dict["labeled"] = [r.value for r in self.labeled] - - if self.labels is not None: - 
chains_dict["labels"] = self.labels - - return chains_dict - - def check(self) -> None: - """Validate the chains structure.""" - # Basic validation - more complex checks could be added - pass - - def extend(self, other: Self) -> None: - """Extend this chains collection with another.""" - # Calculate residue offset - residue_offset = sum(len(chain.residues) for chain in self.chains) - - # Extend chains with renumbered residue references - for chain in other.chains: - new_residue_refs = [ref + residue_offset for ref in chain.residues] - self.chains.append(Chain(new_residue_refs)) - - # Extend secondary structure info - if self.alpha_helices is not None and other.alpha_helices is not None: - new_alpha_helices = [ - ref.value + residue_offset for ref in other.alpha_helices - ] - self.alpha_helices.extend([ResidueRef(ref) for ref in new_alpha_helices]) - - if self.beta_sheets is not None and other.beta_sheets is not None: - new_beta_sheets = [ref.value + residue_offset for ref in other.beta_sheets] - self.beta_sheets.extend([ResidueRef(ref) for ref in new_beta_sheets]) - - def new_chains_from_residue_subset( - self, residue_refs: list[ResidueRef] - ) -> "Chains": - """Create new chains collection from a subset of residue references.""" - new_chains = Chains() - - # Create mapping from old residue indices to new ones - old_to_new_residue = {ref.value: i for i, ref in enumerate(residue_refs)} - - # Group residues by their original chains - chain_to_new_residues = defaultdict(list) - - for new_idx, residue_ref in enumerate(residue_refs): - # Find which chain this residue belonged to - for chain_idx, chain in enumerate(self.chains): - if residue_ref.value in chain.residues: - chain_to_new_residues[chain_idx].append(new_idx) - break - - # Create new chains - for chain_idx in sorted(chain_to_new_residues.keys()): - new_chain_residues = chain_to_new_residues[chain_idx] - # Sort by original sequence order - original_chain = self.chains[chain_idx] - new_chain_residues.sort( - 
key=lambda new_idx: original_chain.residues.index( - residue_refs[new_idx].value - ) - ) - new_chains.chains.append(Chain(new_chain_residues)) - - # Filter secondary structure info - if self.alpha_helices: - new_alpha_helices = [] - for residue_ref in self.alpha_helices: - if residue_ref.value in old_to_new_residue: - new_alpha_helices.append( - ResidueRef(old_to_new_residue[residue_ref.value]) - ) - new_chains.alpha_helices = new_alpha_helices if new_alpha_helices else None - - if self.beta_sheets: - new_beta_sheets = [] - for residue_ref in self.beta_sheets: - if residue_ref.value in old_to_new_residue: - new_beta_sheets.append( - ResidueRef(old_to_new_residue[residue_ref.value]) - ) - new_chains.beta_sheets = new_beta_sheets if new_beta_sheets else None - - return new_chains - - -@dataclass -class TRC: - """ - Combined Topology, Residues, and Chains structure. - This is the main structure for representing molecular systems on the Rush platform. - """ - - topology: Topology = field(default_factory=Topology) - residues: Residues = field(default_factory=Residues) - chains: Chains = field(default_factory=Chains) - - def check(self) -> None: - """Validate the entire TRC structure.""" - self.topology.check() - self.residues.check() - self.chains.check() - - # Check that all atoms are in residues - atom_set = set() - for residue in self.residues.residues: - for atom_idx in residue.atoms: - if atom_idx in atom_set: - raise ValueError(f"Atom {atom_idx} appears in multiple residues") - atom_set.add(atom_idx) - - if len(atom_set) != len(self.topology.symbols): - raise ValueError("Not all atoms are assigned to residues") - - # Check that all residues are in chains - residue_set = set() - for chain in self.chains.chains: - for residue_idx in chain.residues: - if residue_idx >= len(self.residues.residues): - raise ValueError( - f"Chain references invalid residue index: {residue_idx}" - ) - if residue_idx in residue_set: - raise ValueError( - f"Residue {residue_idx} appears in 
multiple chains" - ) - residue_set.add(residue_idx) - - if len(residue_set) != len(self.residues.residues): - raise ValueError("Not all residues are assigned to chains") - - def extend(self, other: Self) -> None: - """Extend this TRC with another TRC.""" - self.topology.extend(other.topology) - self.residues.extend(other.residues) - self.chains.extend(other.chains) - - def new_trc_from_residue_subset(self, residue_refs: list[ResidueRef]) -> "TRC": - """Create new TRC from a subset of residue references.""" - # Get residue subset - residue_subset = [self.residues.residues[ref.value] for ref in residue_refs] - - return TRC( - topology=self.topology.new_topology_from_residue_subset(residue_subset), - residues=self.residues.new_residues_from_subset(residue_refs), - chains=self.chains.new_chains_from_residue_subset(residue_refs), - ) - - -@dataclass(frozen=True) -class ResidueId: - """Unique identifier for a residue.""" - - chain_id: str - sequence_number: int - insertion_code: str - residue_name: str - - def __str__(self) -> str: - return f"{self.chain_id}_{self.sequence_number:>9}_{self.insertion_code}_{self.residue_name}" + return [ + FragmentRef(i) + for (i, f) in enumerate(topology.fragments) + if (i != frag_idx and not near_atoms.isdisjoint(f)) + ] diff --git a/src/rush/objects.py b/src/rush/objects.py index 1123496..2d64445 100644 --- a/src/rush/objects.py +++ b/src/rush/objects.py @@ -140,10 +140,11 @@ class TRCRef: @classmethod def upload(cls, trc: TRC) -> Self: + d = trc.to_dict() return cls( - RushObject.from_dict(upload_object(trc.topology.to_dict())), - RushObject.from_dict(upload_object(trc.residues.to_dict())), - RushObject.from_dict(upload_object(trc.chains.to_dict())), + RushObject.from_dict(upload_object(d["topology"])), + RushObject.from_dict(upload_object(d["residues"])), + RushObject.from_dict(upload_object(d["chains"])), ) def fetch(self) -> TRC: diff --git a/tests/module_output_helpers/test_exess_output_helpers.py 
b/tests/module_output_helpers/test_exess_output_helpers.py index b9f7366..9c294d2 100644 --- a/tests/module_output_helpers/test_exess_output_helpers.py +++ b/tests/module_output_helpers/test_exess_output_helpers.py @@ -135,17 +135,13 @@ def test_optimization_ref_fetch(monkeypatch): assert isinstance(result, OptimizationResult) assert len(result.trajectory) == 1 - assert result.trajectory[0].geometry == [ - 0.0, - 0.0, - 0.0, - 0.7, - 0.5, - 0.0, - -0.7, - 0.5, - 0.0, - ] + import numpy as np + + np.testing.assert_allclose( + result.trajectory[0].geometry, + [[0.0, 0.0, 0.0], [0.7, 0.5, 0.0], [-0.7, 0.5, 0.0]], + atol=1e-6, + ) assert result.steps == [ OptimizationStep(total_energy=-76.0, max_gradient_component=1e-4) ] diff --git a/tests/test_exess_interaction_energy.py b/tests/test_exess_interaction_energy.py index 5a2b5e2..87c4a4f 100644 --- a/tests/test_exess_interaction_energy.py +++ b/tests/test_exess_interaction_energy.py @@ -1,13 +1,15 @@ +import json from pathlib import Path -from rush import FragmentRef, RunOpts, Topology, exess +from rush import FragmentRef, RunOpts, Topology, exess, get_fragments_near_fragment from tests._module_test_utils import assert_run_collects_and_caches def test_exess_interaction_energy(test_data_dir: Path): - topology = Topology.from_json(test_data_dir / "tyk2_ejm_31_t.json") + with (test_data_dir / "tyk2_ejm_31_t.json").open() as f: + topology = Topology.from_dict(json.load(f)) lig_idx = 93 - frag_idcs = topology.get_fragments_near_fragment(lig_idx, 6.0) + [lig_idx] + frag_idcs = get_fragments_near_fragment(topology, lig_idx, 6.0) + [lig_idx] run = exess.interaction_energy( test_data_dir / "tyk2_ejm_31_t.json", lig_idx, diff --git a/tests/test_exess_interaction_energy_gadi.py b/tests/test_exess_interaction_energy_gadi.py index c1698c9..2ae1550 100644 --- a/tests/test_exess_interaction_energy_gadi.py +++ b/tests/test_exess_interaction_energy_gadi.py @@ -1,17 +1,27 @@ +import json import sys from pathlib import Path import pytest 
-from rush import FragmentRef, RunOpts, RunSpec, Topology, exess, fetch_run_info +from rush import ( + FragmentRef, + RunOpts, + RunSpec, + Topology, + exess, + fetch_run_info, + get_fragments_near_fragment, +) from tests._module_test_utils import assert_run_collects_and_caches @pytest.mark.timeout(1800) def test_exess_interaction_energy_gadi(test_data_dir: Path): - topology = Topology.from_json(test_data_dir / "tyk2_ejm_31_t.json") + with (test_data_dir / "tyk2_ejm_31_t.json").open() as f: + topology = Topology.from_dict(json.load(f)) lig_idx = 93 - frag_idcs = topology.get_fragments_near_fragment(lig_idx, 6.0) + [lig_idx] + frag_idcs = get_fragments_near_fragment(topology, lig_idx, 6.0) + [lig_idx] run = exess.interaction_energy( test_data_dir / "tyk2_ejm_31_t.json", lig_idx, diff --git a/tests/test_exess_interaction_energy_setonix.py b/tests/test_exess_interaction_energy_setonix.py index ac1fe02..aa05370 100644 --- a/tests/test_exess_interaction_energy_setonix.py +++ b/tests/test_exess_interaction_energy_setonix.py @@ -1,17 +1,27 @@ +import json import sys from pathlib import Path import pytest -from rush import FragmentRef, RunOpts, RunSpec, Topology, exess, fetch_run_info +from rush import ( + FragmentRef, + RunOpts, + RunSpec, + Topology, + exess, + fetch_run_info, + get_fragments_near_fragment, +) from tests._module_test_utils import assert_run_collects_and_caches @pytest.mark.timeout(1800) def test_exess_interaction_energy_setonix(test_data_dir: Path): - topology = Topology.from_json(test_data_dir / "tyk2_ejm_31_t.json") + with (test_data_dir / "tyk2_ejm_31_t.json").open() as f: + topology = Topology.from_dict(json.load(f)) lig_idx = 93 - frag_idcs = topology.get_fragments_near_fragment(lig_idx, 6.0) + [lig_idx] + frag_idcs = get_fragments_near_fragment(topology, lig_idx, 6.0) + [lig_idx] run = exess.interaction_energy( test_data_dir / "tyk2_ejm_31_t.json", lig_idx, diff --git a/uv.lock b/uv.lock index 39a05a3..76237dc 100644 --- a/uv.lock +++ b/uv.lock @@ 
-450,6 +450,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/91/53255615acd2a1eaca307ede3c90eb550bae9c94581f8c00081b6b1c8f44/kiwisolver-1.5.0-graalpy312-graalpy250_312_native-win_amd64.whl", hash = "sha256:1f1489f769582498610e015a8ef2d36f28f505ab3096d0e16b4858a9ec214f57", size = 75987, upload-time = "2026-03-09T13:15:39.65Z" }, ] +[[package]] +name = "libqdx" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e5/be/83c65426511a9005fce5a920698b3b62836c75ed4daaed917c5e456d4d57/libqdx-0.8.0.tar.gz", hash = "sha256:d3cb57865933d8323c8337181b71bf999024dcf852277702699e3bd6cf5d4ff5", size = 4072101, upload-time = "2026-04-07T03:28:36.178Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/c5/ef8af3b0305401581cf34d988408ba59422354d310a28988f536321681b0/libqdx-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bd2cfec5a4dd7791dcedac0ffa59bcea64409303fd635f3a3253bdfdb66541b3", size = 792621, upload-time = "2026-04-07T03:28:25.689Z" }, + { url = "https://files.pythonhosted.org/packages/be/6a/5a8d0ee4a75ba512d921c69a2dc2efecc9d7e5d10d00afae3557310e6048/libqdx-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e8f39f47befb955b2ce04c9201a2f8587b8d73b14c691f84546edf1fbed5594", size = 905390, upload-time = "2026-04-07T03:28:27.524Z" }, + { url = "https://files.pythonhosted.org/packages/7f/5f/99615ba8c73eee3882c9a24802a20ae147351c62101ee53cb1cf6765ffb8/libqdx-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:90bbfac41ad414b3f7b11fe60a6e0cfce078f36114a263d6e0bf192e61494ed4", size = 673340, upload-time = "2026-04-07T03:28:29.421Z" }, + { url = "https://files.pythonhosted.org/packages/88/f8/185bb34c0c36221daa76a8f068011a0ed96960f49aed43ac87339c8c8cb8/libqdx-0.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:519d9b4016f769027bbee743f1bb9545cb3658ad43892dfc51eced21396f6895", size = 792731, 
upload-time = "2026-04-07T03:28:31.173Z" }, + { url = "https://files.pythonhosted.org/packages/c7/76/1dca69d2fa5e37ea55e16b32a9b5306339f5e5e12d3ccbacf7cec3964c83/libqdx-0.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a3540719026a5f30db8faf8873b17d32bc0251b64595db3054573f985627907", size = 904814, upload-time = "2026-04-07T03:28:32.888Z" }, + { url = "https://files.pythonhosted.org/packages/30/76/aa8940fa3b99ff6c05fb20bbafda25aa324b4de8e15d9abcd476e09327dc/libqdx-0.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:56b7954a3da234a0fec5e9a1eeacc0b925d30b9dcb2719934c8de2985010caaf", size = 673353, upload-time = "2026-04-07T03:28:34.525Z" }, +] + [[package]] name = "markdown-it-py" version = "4.0.0" @@ -1168,11 +1185,12 @@ wheels = [ [[package]] name = "rush-py" -version = "7.0.0" +version = "7.1.0" source = { editable = "." } dependencies = [ { name = "gql" }, { name = "h5py" }, + { name = "libqdx" }, { name = "matplotlib" }, { name = "networkx" }, { name = "numpy" }, @@ -1201,6 +1219,7 @@ dev = [ requires-dist = [ { name = "gql", specifier = "~=4.0" }, { name = "h5py", specifier = "~=3.14" }, + { name = "libqdx", specifier = "~=0.8.0" }, { name = "matplotlib", specifier = "~=3.10" }, { name = "networkx", specifier = "~=3.6" }, { name = "numpy", specifier = ">=1.26,<3" },