Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 55 additions & 10 deletions src/dynsight/_internal/lens/lens.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,35 @@
from MDAnalysis import AtomGroup, Universe
from numpy.typing import NDArray

from multiprocessing import Pool

import numpy as np
from MDAnalysis.lib.NeighborSearch import AtomNeighborSearch


def _process_neighbour_frame(
args: tuple[Universe, AtomGroup, float, int, int],
) -> tuple[int, list[NDArray[np.float64]]]:
universe, selection, cutoff, traj_frame, result_index = args

universe.trajectory[traj_frame]
neigh_search = AtomNeighborSearch(
universe.atoms, box=universe.trajectory.ts.dimensions
)

neigh_list_per_atom = [
neigh_search.search(atom, cutoff).ix for atom in selection
]

return result_index, neigh_list_per_atom


def list_neighbours_along_trajectory(
input_universe: Universe,
cutoff: float,
selection: str = "all",
trajslice: slice | None = None,
num_processes: int = 1,
) -> list[list[AtomGroup]]:
"""Produce a per-frame list of the neighbors, atom by atom.

Expand All @@ -33,6 +53,9 @@ def list_neighbours_along_trajectory(
`here <https://userguide.mdanalysis.org/stable/selections.html>`_
trajslice:
The slice of the trajectory to consider. Defaults to slice(None).
num_processes:
The number of processes to use for parallel computation.
**Warning:** Adjust this based on the available cores.

Returns:
list[list[AtomGroup]]:
Expand Down Expand Up @@ -70,18 +93,40 @@ def list_neighbours_along_trajectory(
"""
if trajslice is None:
trajslice = slice(None)
neigh_list_per_frame = []
selected_atoms = input_universe.select_atoms(selection)
for _ in input_universe.universe.trajectory[trajslice]:
neigh_search = AtomNeighborSearch(
input_universe.atoms, box=input_universe.dimensions
)
frame_indices = list(
range(*trajslice.indices(input_universe.trajectory.n_frames))
)

if num_processes < 1:
msg = "num_processes cannot be negative or zero."
raise ValueError(msg)

if num_processes == 1:
neigh_list_per_frame = []
for traj_frame in frame_indices:
input_universe.trajectory[traj_frame]
neigh_search = AtomNeighborSearch(
input_universe.atoms,
box=input_universe.trajectory.ts.dimensions,
)
neigh_list_per_atom = [
neigh_search.search(atom, cutoff).ix for atom in selected_atoms
]
neigh_list_per_frame.append(neigh_list_per_atom)
return neigh_list_per_frame

args = [
(input_universe, selected_atoms, cutoff, frame, i)
for i, frame in enumerate(frame_indices)
]

with Pool(processes=num_processes) as pool:
results = pool.map(_process_neighbour_frame, args)

ordered_results = dict(results)

neigh_list_per_atom = [
neigh_search.search(atom, cutoff) for atom in selected_atoms
]
neigh_list_per_frame.append([at.ix for at in neigh_list_per_atom])
return neigh_list_per_frame
return [ordered_results[i] for i in range(len(frame_indices))]


def neighbour_change_in_time(
Expand Down
2 changes: 2 additions & 0 deletions src/dynsight/_internal/trajectory/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def get_lens(
delay: int = 1,
selection: str = "all",
neigcounts: list[list[AtomGroup]] | None = None,
num_processes: int = 1,
) -> tuple[list[list[AtomGroup]], Insight]:
"""Compute LENS on the trajectory.

Expand All @@ -185,6 +186,7 @@ def get_lens(
cutoff=r_cut,
selection=selection,
trajslice=self.trajslice,
num_processes=num_processes,
)
lens, *_ = dynsight.lens.neighbour_change_in_time(
neigh_list_per_frame=neigcounts,
Expand Down
7 changes: 7 additions & 0 deletions tests/lens/case_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from dataclasses import dataclass


@dataclass(frozen=True, slots=True)
class LENSCaseData:
num_processes: int
name: str
23 changes: 23 additions & 0 deletions tests/lens/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
import pytest

from tests.lens.case_data import LENSCaseData


def fewframeuniverse(
trajectory: list[list[list[int]]],
Expand Down Expand Up @@ -261,3 +263,24 @@ def lensisonefixtures() -> MDAnalysis.Universe:
)
def lensfixtures(request: pytest.FixtureRequest) -> MDAnalysis.Universe:
return getuni[request.param]


@pytest.fixture(
scope="session",
params=(
# Case 0: Mono Core
lambda name: LENSCaseData(
num_processes=1,
name=name,
),
# Case 1: Multi Core
lambda name: LENSCaseData(
num_processes=2,
name=name,
),
),
)
def case_data(request: pytest.FixtureRequest) -> LENSCaseData:
return request.param(
f"{request.fixturename}{request.param_index}", # type: ignore [attr-defined]
)
8 changes: 6 additions & 2 deletions tests/lens/test_lens.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

from dynsight.trajectory import Trj

from .case_data import LENSCaseData


# Define the actual test
def test_lens_signals() -> None:
def test_lens_signals(case_data: LENSCaseData) -> None:
"""Test the consistency of LENS calculations with a control calculation.

* Original author: Martina Crippa
Expand Down Expand Up @@ -39,7 +41,9 @@ def test_lens_signals() -> None:

# Run LENS (and nn) calculation for different r_cuts
for i, r_cut in enumerate(lens_cutoffs):
neigcounts, test_lens = example_trj.get_lens(r_cut=r_cut)
neigcounts, test_lens = example_trj.get_lens(
r_cut=r_cut, num_processes=case_data.num_processes
)
test_lens_ds = np.array(
[np.concatenate(([0.0], tmp)) for tmp in test_lens.dataset]
) # the inner LENS function has always 0.0 as first frame
Expand Down
21 changes: 15 additions & 6 deletions tests/vision/test_vision.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from unittest.mock import patch

import numpy as np
import yaml
Expand Down Expand Up @@ -180,12 +181,20 @@ def test_vision_tuning(tmp_path: Path) -> None:
)
create_dummy_yolo_dataset(tmp_path)
instance.set_training_dataset(tmp_path / "data.yaml")
hyp = instance.tune_hyperparams(
iterations=1,
epochs=1,
imgsz=100,
batch_size=-1,
)

# Disable Ultralytics plot_tune_results
# It crashes on empty CSVs with dummy datasets
# in newer version of ultralytics package.
with patch(
"ultralytics.utils.plotting.plot_tune_results", lambda *_, **__: None
):
hyp = instance.tune_hyperparams(
iterations=1,
epochs=1,
imgsz=100,
batch_size=1,
)

assert (
out_path / "tuning" / "results" / "best_hyperparameters.yaml"
).exists()
Expand Down