diff --git a/nvmolkit/_mmff_bridge.py b/nvmolkit/_mmff_bridge.py index c869f80..8b2eaca 100644 --- a/nvmolkit/_mmff_bridge.py +++ b/nvmolkit/_mmff_bridge.py @@ -26,7 +26,6 @@ from __future__ import annotations from typing import TYPE_CHECKING -import weakref from rdkit.Chem import rdForceFieldHelpers from rdkit.ForceField import rdForceField as _rdForceField # noqa: F401 @@ -38,89 +37,13 @@ from rdkit.ForceField.rdForceField import MMFFMolProperties as RDKitMMFFMolProperties -_DEFAULT_MMFF_SETTINGS = { - "variant": "MMFF94", - "dielectric_constant": 1.0, - "dielectric_model": 1, - "bond_term": True, - "angle_term": True, - "stretch_bend_term": True, - "oop_term": True, - "torsion_term": True, - "vdw_term": True, - "ele_term": True, -} -_CAPTURED_MMFF_SETTINGS_BY_OBJECT: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() - - -def _normalize_mmff_settings(settings: dict | None) -> dict: - normalized = dict(_DEFAULT_MMFF_SETTINGS) - if settings is not None: - normalized.update(settings) - return normalized - - -def capture_mmff_settings(properties: "RDKitMMFFMolProperties", settings: dict | None): - """Associate explicit MMFF settings with an RDKit MMFF properties object. - - We use this when nvmolkit itself creates/configures a Python - ``MMFFMolProperties`` object so later conversion back into nvMolKit's - internal MMFF transport does not depend on RDKit exposing Python getters for - every setting in a given build. - """ - - _CAPTURED_MMFF_SETTINGS_BY_OBJECT[properties] = _normalize_mmff_settings(settings) - return properties - - def default_rdkit_mmff_properties(mol: "Mol"): - """Create default RDKit MMFF properties and capture their default settings.""" + """Create default RDKit MMFF properties for ``mol``.""" properties = rdForceFieldHelpers.MMFFGetMoleculeProperties(mol) if properties is None: raise ValueError("RDKit could not create MMFF properties for molecule") - return capture_mmff_settings(properties, None) - - -def extract_mmff_settings(properties: "RDKitMMFFMolProperties") -> dict: - """Return MMFF settings for an RDKit MMFF properties object. - - Captured settings take precedence. If we did not create/configure the object - ourselves, fall back to RDKit getter methods when the current build exposes - them. - """ - - captured = _CAPTURED_MMFF_SETTINGS_BY_OBJECT.get(properties) - if captured is not None: - return dict(captured) - - getter_candidates = { - "variant": ("GetMMFFVariant", "getMMFFVariant"), - "dielectric_constant": ("GetMMFFDielectricConstant", "getMMFFDielectricConstant"), - "dielectric_model": ("GetMMFFDielectricModel", "getMMFFDielectricModel"), - "bond_term": ("GetMMFFBondTerm", "getMMFFBondTerm"), - "angle_term": ("GetMMFFAngleTerm", "getMMFFAngleTerm"), - "stretch_bend_term": ("GetMMFFStretchBendTerm", "getMMFFStretchBendTerm"), - "oop_term": ("GetMMFFOopTerm", "getMMFFOopTerm"), - "torsion_term": ("GetMMFFTorsionTerm", "getMMFFTorsionTerm"), - "vdw_term": ("GetMMFFVdWTerm", "getMMFFVdWTerm"), - "ele_term": ("GetMMFFEleTerm", "getMMFFEleTerm"), - } - extracted = {} - for key, names in getter_candidates.items(): - getter = None - for name in names: - if hasattr(properties, name): - getter = getattr(properties, name) - break - if getter is None: - raise TypeError( - "Could not read MMFF settings from the supplied RDKit MMFFMolProperties object. " - "Use an object created via rdForceFieldHelpers.MMFFGetMoleculeProperties() and " - "configured before passing it to nvmolkit." - ) - extracted[key] = getter() - return _normalize_mmff_settings(extracted) + return properties def make_internal_mmff_properties( @@ -131,24 +54,13 @@ def make_internal_mmff_properties( ): """Convert an RDKit MMFF properties object into nvMolKit's internal transport. - Unlike RDKit ``Mol`` objects, RDKit's Python ``MMFFMolProperties`` wrapper is - not passed directly into nvMolKit's extension module. The native code instead - receives this plain internal ``MMFFProperties`` object with the RDKit settings - copied onto it. + RDKit's Python binding only exposes setters for the scalar MMFF settings + (variant, dielectric, per-term flags); the corresponding getters are not + wrapped. We read the settings through the C++ binding layer instead. """ - settings = extract_mmff_settings(properties) - internal = _batchedForcefield.MMFFProperties() - internal.variant = str(settings["variant"]) - internal.dielectricConstant = float(settings["dielectric_constant"]) - internal.dielectricModel = int(settings["dielectric_model"]) - internal.nonBondedThreshold = float(non_bonded_threshold) - internal.ignoreInterfragInteractions = bool(ignore_interfrag_interactions) - internal.bondTerm = bool(settings["bond_term"]) - internal.angleTerm = bool(settings["angle_term"]) - internal.stretchBendTerm = bool(settings["stretch_bend_term"]) - internal.oopTerm = bool(settings["oop_term"]) - internal.torsionTerm = bool(settings["torsion_term"]) - internal.vdwTerm = bool(settings["vdw_term"]) - internal.eleTerm = bool(settings["ele_term"]) - return internal + return _batchedForcefield.buildMMFFPropertiesFromRDKit( + properties, + float(non_bonded_threshold), + bool(ignore_interfrag_interactions), + ) diff --git a/nvmolkit/batchedForcefield.cpp b/nvmolkit/batchedForcefield.cpp index 0af0943..7a196de 100644 --- a/nvmolkit/batchedForcefield.cpp +++ b/nvmolkit/batchedForcefield.cpp @@ -422,6 +422,11 @@ BOOST_PYTHON_MODULE(_batchedForcefield) { .def_readwrite("vdwTerm", &nvMolKit::MMFFProperties::vdwTerm) .def_readwrite("eleTerm", &nvMolKit::MMFFProperties::eleTerm); + bp::def("buildMMFFPropertiesFromRDKit", + &nvMolKit::buildMMFFPropertiesFromRDKit, + (bp::arg("rdkit_properties"), bp::arg("non_bonded_threshold"), bp::arg("ignore_interfrag_interactions")), + "Build an nvMolKit MMFFProperties transport from an RDKit MMFFMolProperties Python object."); + bp::class_("NativeMMFFBatchedForcefield", bp::init + #include +#include +#include #include #include "mmff_properties.h" +namespace ForceFields { + +/// \brief Layout-compatible shim for RDKit's Python wrapper class. +/// +/// RDKit registers its Python \c MMFFMolProperties binding as +/// \c ForceFields::PyMMFFMolProperties (declared in the un-installed Wrap-layer +/// header \c Code/ForceField/Wrap/PyForceField.h) which holds the real C++ +/// \c RDKit::MMFF::MMFFMolProperties via a public \c boost::shared_ptr member. +/// This declaration matches the layout of RDKit's class so Boost.Python's type +/// registry maps the Python object to the same \c type_info, letting us read +/// the underlying \c MMFFMolProperties through its public shared_ptr member. +/// The class has no virtual methods in RDKit, so only the single member needs +/// to match for RTTI-based lookup to resolve. +/// +/// \warning This is brittle and relies on Linux symbol resolution to make the +/// \c type_info here compare equal to RDKit's. The long-term fix is for RDKit +/// to expose the scalar-setting getters on its Python binding so this shim +/// can go away. +/// TODO: This will go away after https://github.com/rdkit/rdkit/issues/9253 is implemented +/// but we'll need to keep it as backup as long as we support older versions of RDKit. +class PyMMFFMolProperties { + public: + boost::shared_ptr mmffMolProperties; +}; + +static_assert(sizeof(PyMMFFMolProperties) == sizeof(boost::shared_ptr), + "nvMolKit's PyMMFFMolProperties shim must hold exactly one boost::shared_ptr; " + "adding fields or virtual methods here breaks the layout contract with RDKit."); + +} // namespace ForceFields + namespace nvMolKit { +/// \brief Populate an nvMolKit MMFF transport from an RDKit MMFFMolProperties Python object. +/// +/// RDKit's Python binding for \c MMFFMolProperties only exposes setters for the +/// scalar settings (variant, dielectric, per-term flags); the corresponding getters +/// are unbound. This helper peeks at the underlying C++ +/// \c RDKit::MMFF::MMFFMolProperties through RDKit's Python wrapper shim and reads +/// the settings with the C++ getters directly. +inline MMFFProperties buildMMFFPropertiesFromRDKit(const boost::python::object& pyProps, + double nonBondedThreshold, + bool ignoreInterfragInteractions) { + boost::python::extract extractor(pyProps); + if (!extractor.check()) { + throw std::invalid_argument("buildMMFFPropertiesFromRDKit: expected an RDKit MMFFMolProperties object"); + } + ForceFields::PyMMFFMolProperties* pyWrapper = extractor(); + if (pyWrapper == nullptr || pyWrapper->mmffMolProperties.get() == nullptr) { + throw std::invalid_argument("buildMMFFPropertiesFromRDKit: null MMFFMolProperties pointer"); + } + RDKit::MMFF::MMFFMolProperties* rdProps = pyWrapper->mmffMolProperties.get(); + + MMFFProperties props; + props.variant = rdProps->getMMFFVariant(); + props.dielectricConstant = rdProps->getMMFFDielectricConstant(); + props.dielectricModel = static_cast(rdProps->getMMFFDielectricModel()); + props.nonBondedThreshold = nonBondedThreshold; + props.ignoreInterfragInteractions = ignoreInterfragInteractions; + props.bondTerm = rdProps->getMMFFBondTerm(); + props.angleTerm = rdProps->getMMFFAngleTerm(); + props.stretchBendTerm = rdProps->getMMFFStretchBendTerm(); + props.oopTerm = rdProps->getMMFFOopTerm(); + props.torsionTerm = rdProps->getMMFFTorsionTerm(); + props.vdwTerm = rdProps->getMMFFVdWTerm(); + props.eleTerm = rdProps->getMMFFEleTerm(); + return props; +} + inline MMFFProperties extractMMFFProperties(const boost::python::object& obj, double nonBondedThreshold = 100.0, bool ignoreInterfragInteractions = true) { diff --git a/nvmolkit/tests/test_batched_forcefield.py b/nvmolkit/tests/test_batched_forcefield.py index 0c12368..fdb5fdd 100644 --- a/nvmolkit/tests/test_batched_forcefield.py +++ b/nvmolkit/tests/test_batched_forcefield.py @@ -22,7 +22,6 @@ from rdkit.ForceField import rdForceField as _rdForceField # noqa: F401 from rdkit.Geometry import Point3D -from nvmolkit._mmff_bridge import capture_mmff_settings from nvmolkit.batchedForcefield import MMFFBatchedForcefield, UFFBatchedForcefield from nvmolkit.types import HardwareOptions @@ -120,7 +119,7 @@ def make_rdkit_mmff_properties(mol, settings: dict | None = None): mmff_props.SetMMFFTorsionTerm(settings.get("torsion_term", True)) mmff_props.SetMMFFVdWTerm(settings.get("vdw_term", True)) mmff_props.SetMMFFEleTerm(settings.get("ele_term", True)) - return capture_mmff_settings(mmff_props, settings) + return mmff_props def make_rdkit_mmff_forcefield( @@ -296,6 +295,37 @@ def test_mmff_batched_forcefield_properties_match_rdkit(): ) +def test_mmff_batched_forcefield_reads_externally_configured_properties(): + """Configure RDKit MMFF properties via raw ``rdForceFieldHelpers.MMFFGetMoleculeProperties`` + plus direct ``SetMMFF*Term``/``SetMMFFDielectricConstant`` calls — no nvmolkit helpers + in the path — then hand the object to ``MMFFBatchedForcefield``. + + Needed because of our workaround for RDKit bug https://github.com/rdkit/rdkit/issues/9253 + """ + mol = make_embedded_mol("CCO") + + props = rdForceFieldHelpers.MMFFGetMoleculeProperties(mol) + assert props is not None + props.SetMMFFBondTerm(False) + props.SetMMFFTorsionTerm(False) + props.SetMMFFDielectricConstant(2.5) + + forcefield = MMFFBatchedForcefield(clone_mols([mol]), properties=[props]) + got_energy = forcefield.compute_energy()[0][0] + got_grad = forcefield.compute_gradients()[0][0] + + rd_ff = rdForceFieldHelpers.MMFFGetMoleculeForceField(mol, props) + rd_energy = rd_ff.CalcEnergy() + rd_grad = list(rd_ff.CalcGrad()) + assert_energy_and_gradient_close(got_energy, rd_energy, got_grad, rd_grad) + + default_forcefield = MMFFBatchedForcefield(clone_mols([mol])) + default_energy = default_forcefield.compute_energy()[0][0] + assert abs(got_energy - default_energy) > 1e-6, ( + "term toggles on externally-configured MMFFMolProperties had no observable effect on the batched energy" + ) + + def test_mmff_batched_forcefield_constraints_match_rdkit(): """Batch of mols with all 5 MMFF constraint types applied (one per mol), some also carrying non-default property settings to exercise the properties+constraints path.""" diff --git a/nvmolkit/tests/test_mmff_optimization.py b/nvmolkit/tests/test_mmff_optimization.py index 158ba53..48b19f5 100644 --- a/nvmolkit/tests/test_mmff_optimization.py +++ b/nvmolkit/tests/test_mmff_optimization.py @@ -22,7 +22,6 @@ from rdkit.ForceField import rdForceField as _rdForceField # noqa: F401 from rdkit.Geometry import Point3D -from nvmolkit._mmff_bridge import capture_mmff_settings from nvmolkit.embedMolecules import EmbedMolecules import nvmolkit.mmffOptimization as nvmolkit_mmff from nvmolkit.types import HardwareOptions @@ -120,7 +119,7 @@ def make_rdkit_mmff_properties(mol, settings: dict | None = None): mmff_props.SetMMFFTorsionTerm(settings.get("torsion_term", True)) mmff_props.SetMMFFVdWTerm(settings.get("vdw_term", True)) mmff_props.SetMMFFEleTerm(settings.get("ele_term", True)) - return capture_mmff_settings(mmff_props, settings) + return mmff_props def calculate_rdkit_mmff_energies(