From 0532691088ee49dee2fa8f928d0da8869800e131 Mon Sep 17 00:00:00 2001 From: SimoneMartino98 Date: Thu, 25 Sep 2025 17:30:19 +0200 Subject: [PATCH 1/4] parallelized neighbor LENS calculation. --- src/dynsight/_internal/lens/lens.py | 61 ++++++++++++++++--- .../_internal/trajectory/trajectory.py | 2 + 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/src/dynsight/_internal/lens/lens.py b/src/dynsight/_internal/lens/lens.py index 987066b6..d5c4b64b 100644 --- a/src/dynsight/_internal/lens/lens.py +++ b/src/dynsight/_internal/lens/lens.py @@ -8,15 +8,35 @@ from MDAnalysis import AtomGroup, Universe from numpy.typing import NDArray +from multiprocessing import Pool + import numpy as np from MDAnalysis.lib.NeighborSearch import AtomNeighborSearch +def _process_neighbour_frame( + args: tuple[Universe, AtomGroup, float, int, int], +) -> tuple[int, list[NDArray[np.float64]]]: + universe, selection, cutoff, traj_frame, result_index = args + + universe.trajectory[traj_frame] + neigh_search = AtomNeighborSearch( + universe.atoms, box=universe.trajectory.ts.dimensions + ) + + neigh_list_per_atom = [ + neigh_search.search(atom, cutoff).ix for atom in selection + ] + + return result_index, neigh_list_per_atom + + def list_neighbours_along_trajectory( input_universe: Universe, cutoff: float, selection: str = "all", trajslice: slice | None = None, + num_processes: int = 1, ) -> list[list[AtomGroup]]: """Produce a per-frame list of the neighbors, atom by atom. @@ -33,6 +53,9 @@ def list_neighbours_along_trajectory( `here `_ trajslice: The slice of the trajectory to consider. Defaults to slice(None). + num_processes: + The number of processes to use for parallel computation. + **Warning:** Adjust this based on the available cores. Returns: list[list[AtomGroup]]: @@ -70,18 +93,36 @@ def list_neighbours_along_trajectory( """ if trajslice is None: trajslice = slice(None) - neigh_list_per_frame = [] selected_atoms = input_universe.select_atoms(selection) - for _ in input_universe.universe.trajectory[trajslice]: - neigh_search = AtomNeighborSearch( - input_universe.atoms, box=input_universe.dimensions - ) + frame_indices = list( + range(*trajslice.indices(input_universe.trajectory.n_frames)) + ) + + if num_processes <= 1: + neigh_list_per_frame = [] + for traj_frame in frame_indices: + input_universe.trajectory[traj_frame] + neigh_search = AtomNeighborSearch( + input_universe.atoms, + box=input_universe.trajectory.ts.dimensions, + ) + neigh_list_per_atom = [ + neigh_search.search(atom, cutoff).ix for atom in selected_atoms + ] + neigh_list_per_frame.append(neigh_list_per_atom) + return neigh_list_per_frame + + args = [ + (input_universe, selected_atoms, cutoff, frame, i) + for i, frame in enumerate(frame_indices) + ] + + with Pool(processes=num_processes) as pool: + results = pool.map(_process_neighbour_frame, args) + + ordered_results = dict(results) - neigh_list_per_atom = [ - neigh_search.search(atom, cutoff) for atom in selected_atoms - ] - neigh_list_per_frame.append([at.ix for at in neigh_list_per_atom]) - return neigh_list_per_frame + return [ordered_results[i] for i in range(len(frame_indices))] def neighbour_change_in_time( diff --git a/src/dynsight/_internal/trajectory/trajectory.py b/src/dynsight/_internal/trajectory/trajectory.py index ec6df090..d45d3e29 100644 --- a/src/dynsight/_internal/trajectory/trajectory.py +++ b/src/dynsight/_internal/trajectory/trajectory.py @@ -169,6 +169,7 @@ def get_lens( delay: int = 1, selection: str = "all", neigcounts: list[list[AtomGroup]] | None = None, + num_processes: int = 1, ) -> tuple[list[list[AtomGroup]], Insight]: """Compute LENS on the trajectory. @@ -185,6 +186,7 @@ def get_lens( cutoff=r_cut, selection=selection, trajslice=self.trajslice, + num_processes=num_processes, ) lens, *_ = dynsight.lens.neighbour_change_in_time( neigh_list_per_frame=neigcounts, From 3a2d184412d221a8f234ae77994c8c24fb9d4516 Mon Sep 17 00:00:00 2001 From: SimoneMartino98 Date: Fri, 26 Sep 2025 09:37:32 +0200 Subject: [PATCH 2/4] added check for negative n_cores. --- src/dynsight/_internal/lens/lens.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/dynsight/_internal/lens/lens.py b/src/dynsight/_internal/lens/lens.py index d5c4b64b..f78162ea 100644 --- a/src/dynsight/_internal/lens/lens.py +++ b/src/dynsight/_internal/lens/lens.py @@ -97,8 +97,12 @@ def list_neighbours_along_trajectory( frame_indices = list( range(*trajslice.indices(input_universe.trajectory.n_frames)) ) - - if num_processes <= 1: + + if num_processes < 1: + msg="num_processes cannot be negative or zero." + raise ValueError(msg) + + if num_processes == 1: neigh_list_per_frame = [] for traj_frame in frame_indices: input_universe.trajectory[traj_frame] From 1b56caf36ff1c0da054b23ba9e533daf417f903f Mon Sep 17 00:00:00 2001 From: SimoneMartino98 Date: Fri, 26 Sep 2025 10:16:22 +0200 Subject: [PATCH 3/4] Fix plot error during vision tests and formatting. --- src/dynsight/_internal/lens/lens.py | 6 +++--- tests/vision/test_vision.py | 21 +++++++++++++++------ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/dynsight/_internal/lens/lens.py b/src/dynsight/_internal/lens/lens.py index f78162ea..0da9ea9f 100644 --- a/src/dynsight/_internal/lens/lens.py +++ b/src/dynsight/_internal/lens/lens.py @@ -97,11 +97,11 @@ def list_neighbours_along_trajectory( frame_indices = list( range(*trajslice.indices(input_universe.trajectory.n_frames)) ) - + if num_processes < 1: - msg="num_processes cannot be negative or zero." + msg = "num_processes cannot be negative or zero." raise ValueError(msg) - + if num_processes == 1: neigh_list_per_frame = [] for traj_frame in frame_indices: diff --git a/tests/vision/test_vision.py b/tests/vision/test_vision.py index 32398fb2..b8243036 100644 --- a/tests/vision/test_vision.py +++ b/tests/vision/test_vision.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from unittest.mock import patch import numpy as np import yaml @@ -180,12 +181,20 @@ def test_vision_tuning(tmp_path: Path) -> None: ) create_dummy_yolo_dataset(tmp_path) instance.set_training_dataset(tmp_path / "data.yaml") - hyp = instance.tune_hyperparams( - iterations=1, - epochs=1, - imgsz=100, - batch_size=-1, - ) + + # Disable Ultralytics plot_tune_results + # It crashes on empty CSVs with dummy datasets + # in newer version of ultralytics package. + with patch( + "ultralytics.utils.plotting.plot_tune_results", lambda *_, **__: None + ): + hyp = instance.tune_hyperparams( + iterations=1, + epochs=1, + imgsz=100, + batch_size=1, + ) + assert ( out_path / "tuning" / "results" / "best_hyperparameters.yaml" ).exists() From d145d57673f5f2be3d34f063a912c767d25991a2 Mon Sep 17 00:00:00 2001 From: SimoneMartino98 Date: Fri, 26 Sep 2025 11:26:20 +0200 Subject: [PATCH 4/4] Added case_data to test both mono and multi core. --- tests/lens/case_data.py | 7 +++++++ tests/lens/conftest.py | 23 +++++++++++++++++++++++ tests/lens/test_lens.py | 8 ++++++-- 3 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 tests/lens/case_data.py diff --git a/tests/lens/case_data.py b/tests/lens/case_data.py new file mode 100644 index 00000000..262d0322 --- /dev/null +++ b/tests/lens/case_data.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class LENSCaseData: + num_processes: int + name: str diff --git a/tests/lens/conftest.py b/tests/lens/conftest.py index b36d14b3..a679195d 100644 --- a/tests/lens/conftest.py +++ b/tests/lens/conftest.py @@ -4,6 +4,8 @@ import numpy as np import pytest +from tests.lens.case_data import LENSCaseData + def fewframeuniverse( trajectory: list[list[list[int]]], @@ -261,3 +263,24 @@ def lensisonefixtures() -> MDAnalysis.Universe: ) def lensfixtures(request: pytest.FixtureRequest) -> MDAnalysis.Universe: return getuni[request.param] + + +@pytest.fixture( + scope="session", + params=( + # Case 0: Mono Core + lambda name: LENSCaseData( + num_processes=1, + name=name, + ), + # Case 1: Multi Core + lambda name: LENSCaseData( + num_processes=2, + name=name, + ), + ), +) +def case_data(request: pytest.FixtureRequest) -> LENSCaseData: + return request.param( + f"{request.fixturename}{request.param_index}", # type: ignore [attr-defined] + ) diff --git a/tests/lens/test_lens.py b/tests/lens/test_lens.py index d878ddce..b01dd0ee 100644 --- a/tests/lens/test_lens.py +++ b/tests/lens/test_lens.py @@ -5,9 +5,11 @@ from dynsight.trajectory import Trj +from .case_data import LENSCaseData + # Define the actual test -def test_lens_signals() -> None: +def test_lens_signals(case_data: LENSCaseData) -> None: """Test the consistency of LENS calculations with a control calculation. * Original author: Martina Crippa @@ -39,7 +41,9 @@ def test_lens_signals() -> None: # Run LENS (and nn) calculation for different r_cuts for i, r_cut in enumerate(lens_cutoffs): - neigcounts, test_lens = example_trj.get_lens(r_cut=r_cut) + neigcounts, test_lens = example_trj.get_lens( + r_cut=r_cut, num_processes=case_data.num_processes + ) test_lens_ds = np.array( [np.concatenate(([0.0], tmp)) for tmp in test_lens.dataset] ) # the inner LENS function has always 0.0 as first frame