diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 3921413d..fa28dfa1 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -1,78 +1,78 @@ name: Tests on: - push: - branches: - - main - pull_request: - workflow_dispatch: + push: + branches: + - main + pull_request: + workflow_dispatch: jobs: - ruff: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10", "3.13"] - name: ruff on ${{ matrix.python-version }} - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - cache: "pip" - - run: pip install -e '.[dev]' - - run: ruff check . - mypy: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10", "3.13"] - name: mypy on ${{ matrix.python-version }} - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - cache: "pip" - - run: pip install -e '.[dev]' - - run: mypy . - ruff-format: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10", "3.13"] - name: ruff-format on ${{ matrix.python-version }} - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - cache: "pip" - - run: pip install -e '.[dev]' - - run: ruff format --check . - pytest: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10", "3.13"] - name: pytest on ${{ matrix.python-version }} - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - cache: "pip" - - run: pip install -e '.[dev]' - - run: pytest --cov=src --cov-report term-missing - doctest: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10", "3.13"] - name: doctest on ${{ matrix.python-version }} - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - cache: "pip" - - run: pip install -e '.[dev]' - - run: make -C docs doctest + ruff: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.13"] + name: ruff on ${{ matrix.python-version }} + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + - run: pip install -e '.[dev]' + - run: ruff check . + mypy: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.13"] + name: mypy on ${{ matrix.python-version }} + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + - run: pip install -e '.[dev]' + - run: mypy . + ruff-format: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.13"] + name: ruff-format on ${{ matrix.python-version }} + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + - run: pip install -e '.[dev]' + - run: ruff format --check . 
+ pytest: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.13"] + name: pytest on ${{ matrix.python-version }} + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + - run: pip install -e '.[dev]' + - run: pytest --cov=src --cov-report term-missing + doctest: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.13"] + name: doctest on ${{ matrix.python-version }} + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" + - run: pip install -e '.[dev]' + - run: make -C docs doctest diff --git a/docs/source/_static/label_bar.png b/docs/source/_static/label_bar.png new file mode 100644 index 00000000..da11d389 Binary files /dev/null and b/docs/source/_static/label_bar.png differ diff --git a/docs/source/_static/label_menu.png b/docs/source/_static/label_menu.png new file mode 100644 index 00000000..f5b84609 Binary files /dev/null and b/docs/source/_static/label_menu.png differ diff --git a/docs/source/_static/label_tool.png b/docs/source/_static/label_tool.png new file mode 100644 index 00000000..2123ba38 Binary files /dev/null and b/docs/source/_static/label_tool.png differ diff --git a/docs/source/_static/style.css b/docs/source/_static/style.css new file mode 100644 index 00000000..515773f3 --- /dev/null +++ b/docs/source/_static/style.css @@ -0,0 +1,23 @@ +/*DOCS SIDEBARS AND CENTERING*/ +.sidebar-drawer, +.sidebar-container, +.toc-drawer { + width: 15em; + min-width: 15em; +} +@media (min-width: 65em) { + .content { + margin-left: auto; + margin-right: auto; + } +} +@media (min-width: 82em) { + .page { + justify-content: center; + } + .toc-drawer { + position: fixed; + right: 0; + top: 0; + } +} diff --git a/docs/source/conf.py b/docs/source/conf.py index 376d1d05..d6008384 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -8,6 +8,10 @@ from __future__ import annotations import importlib +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sphinx.application import Sphinx project = "dynsight" project_copyright = "2023, Andrew Tarzia" @@ -37,7 +41,7 @@ autodoc_typehints = "description" autodoc_member_order = "groupwise" -autoclass_content = "class" +autoclass_content = "both" intersphinx_mapping = { "python": ("https://docs.python.org/3", None), @@ -53,3 +57,8 @@ html_theme = "furo" html_static_path = ["_static"] + + +def setup(app: Sphinx) -> None: + """Configure the Sphinx app by adding a custom CSS file.""" + app.add_css_file("style.css") diff --git a/docs/source/index.rst b/docs/source/index.rst index a4d0e965..a52ae8ac 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -10,6 +10,7 @@ :maxdepth: 2 vision + track descriptors onion clustering analysis @@ -75,10 +76,6 @@ you are using Python 3.10 and below, you can use :mod:`cpctools` to access $ pip install cpctools -If you want to use the :mod:`dynsight.vision` and :mod:`dynsight.track` modules -you will need to install a series of packages. 
This can be done with with pip::
-
-   $ pip install ultralytics PyYAML
 
 How to get started
 ------------------
diff --git a/docs/source/label_tool/label_tool.rst b/docs/source/label_tool/label_tool.rst
new file mode 100644
index 00000000..2882baec
--- /dev/null
+++ b/docs/source/label_tool/label_tool.rst
@@ -0,0 +1,70 @@
+The Label Tool
+==============
+
+The ``dynsight label_tool`` is a simple web application that allows users to
+label images. Picture labelling is a crucial step in many computer vision
+tasks, such as the creation of an initial training dataset for Convolutional
+Neural Network (CNN) models. The current version of
+`dynsight vision <../_autosummary/dynsight.vision.VisionInstance.html>`_
+exploits the power of the `YOLO models `_
+for computer vision tasks. Thus, the ``label_tool`` has been specifically
+designed to work with the YOLO dataset format.
+
+.. image:: ../_static/label_tool.png
+
+----------
+How to Use
+----------
+
+The ``label_tool`` application can be executed in two main ways:
+
+* As a standalone application: run the following command in the environment
+  where dynsight is installed:
+
+.. code-block:: bash
+
+   $ label_tool
+
+* From Python code:
+
+.. code-block:: python
+
+   import dynsight
+
+   dynsight.vision.label_tool(port=8888)  # port selection is optional
+
+In both cases, a localhost server should start and the application should
+automatically open in your default web browser.
+
+.. tip::
+
+   If the application does not open automatically, you can open it manually
+   by copying and pasting the URL provided in the terminal output.
+
+-------
+The GUI
+-------
+
+The ``label_tool`` Graphical User Interface is divided into three main panels:
+
+* **The image panel**: where loaded images appear and labels can be drawn.
+
+* **The label menu panel**: where labels can be created and edited.
+
+.. image:: ../_static/label_menu.png
+   :align: center
+
+* **The commands panel**: where all the available commands can be executed.
+
+.. image:: ../_static/label_bar.png
+
+Using the ``Choose File`` button, users can select the image(s) they want to
+label. Once an image is loaded, users can start drawing labels by clicking
+and dragging on the image panel. The label menu panel allows users to create
+and edit labels. Finally, the commands panel provides a set of export options:
+
+* **Export label**: Download a single ``.txt`` file in YOLO format containing
+  the labels for the current image.
+
+* **Export dataset**: Download a YOLO dataset built from the loaded images and
+  their labels, together with the initial YAML configuration file to be used
+  in the YOLO training process.
+
+* **Synthesize dataset**: Create a synthetic dataset from the drawn labels by
+  randomizing the object positions across different images (useful when only
+  a few images are available).
diff --git a/docs/source/track.rst b/docs/source/track.rst
new file mode 100644
index 00000000..2a1049f7
--- /dev/null
+++ b/docs/source/track.rst
@@ -0,0 +1,14 @@
+track
+=====
+
+This module tracks particle trajectories in ``.xyz`` files that lack explicit
+particle IDs, assigning consistent identities across frames.
+It is especially useful for outputs generated by the :doc:`vision` module.
+
+Functions
+---------
+
+.. toctree::
+   :maxdepth: 1
+
+   track_xyz <_autosummary/dynsight.track.track_xyz>
diff --git a/docs/source/vision.rst b/docs/source/vision.rst
index 4f08e89e..62b7264a 100644
--- a/docs/source/vision.rst
+++ b/docs/source/vision.rst
@@ -10,5 +10,5 @@
 Usage
 
 .. toctree::
    :maxdepth: 1
 
-   Video <_autosummary/dynsight.vision.Video.rst>
-   Detect <_autosummary/dynsight.vision.Detect.rst>
+   label_tool
+   VisionInstance <_autosummary/dynsight.vision.VisionInstance.rst>
diff --git a/examples/video_detection.py b/examples/video_detection.py
deleted file mode 100644
index 9d5695cb..00000000
--- a/examples/video_detection.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import argparse
-from pathlib import Path
-
-import dynsight
-
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Dynsight video detection")
-    parser.add_argument(
-        "--prepare",
-        action="store_true",
-        help="Prepare synthesized data.",
-    )
-    parser.add_argument(
-        "--train",
-        action="store_true",
-        help="Train the model.",
-    )
-    parser.add_argument(
-        "--predict",
-        action="store_true",
-        help="Predict the video.",
-    )
-    parser.add_argument(
-        "--input",
-        type=Path,
-        required=True,
-        help="Path to input file.",
-    )
-    parser.add_argument(
-        "--output",
-        type=Path,
-        required=True,
-        help="Path to output folder.",
-    )
-    parser.add_argument(
-        "--maxcycle",
-        type=int,
-        required=False,
-        default=5,
-        help="Number of fitting cycle procedures.",
-    )
-    parser.add_argument(
-        "--epochs",
-        type=int,
-        required=False,
-        default=100,
-        help="Number of epochs for each training session.",
-    )
-    parser.add_argument(
-        "--patience",
-        type=int,
-        required=False,
-        default=10,
-        help="Patience in terms of epochs to earlystop the procedure.",
-    )
-    parser.add_argument(
-        "--batchsize",
-        type=int,
-        required=False,
-        default=2,
-        help="BatchSize for training.",
-    )
-    parser.add_argument(
-        "--workers", type=int, required=True, help="Number of CPUs used."
-    )
-    parser.add_argument(
-        "--gpu",
-        nargs="+",
-        type=int,
-        default=None,
-        help="IDs of the GPU(s) used.",
-    )
-    parser.add_argument(
-        "--step",
-        type=int,
-        required=False,
-        default=1,
-        help="Step for frame to be used.",
-    )
-    parser.add_argument(
-        "--detect_model",
-        type=Path,
-        required=False,
-        default=None,
-        help="Model to be used for predictions.",
-    )
-    return parser.parse_args()
-
-
-def main() -> None:
-    args = parse_args()
-    input_video = args.input
-    output_project = args.output
-    video = dynsight.vision.Video(input_video)
-    detection = dynsight.vision.Detect(
-        input_video=video,
-        project_folder=output_project,
-    )
-    if args.prepare:
-        detection.synthesize()
-    if args.train:
-        detection.fit(
-            initial_dataset=detection.get_project_path()
-            / "training_options.yaml",
-            max_sessions=args.maxcycle,
-            training_epochs=args.epochs,
-            training_patience=args.patience,
-            batch_size=args.batchsize,
-            workers=args.workers,
-            device=args.gpu,
-            frame_reading_step=args.step,
-        )
-    if args.predict:
-        detection.predict_frames(model_path=args.detect_model)
-        detection.compute_xyz(
-            prediction_folder_path=Path("prediction"),
-            output_path=Path.cwd(),
-        )
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/video_to_trajectory.py b/examples/video_to_trajectory.py
new file mode 100644
index 00000000..97d0dfaf
--- /dev/null
+++ b/examples/video_to_trajectory.py
@@ -0,0 +1,86 @@
+"""Virtualization of a video in order to extract trajectory information.
+
+In this example, we show how to use the dynsight ``vision`` module
+combined with the ``track`` module to obtain a trajectory file from a video.
+
+A very simple video is used to demonstrate the workflow; for this reason, the
+default detection model is used. For more complex videos, it is recommended
+to use the ``label_tool`` to create a synthetic dataset, train an initial
+model on it, and then follow the subsequent steps.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+
+from dynsight.track import track_xyz
+from dynsight.vision import VisionInstance
+
+
+def plot_results(
+    instance: VisionInstance,
+    output_path: Path,
+    name: str,
+) -> None:
+    if instance.prediction_results is None:
+        msg = "No prediction results found"
+        raise ValueError(msg)
+
+    n_detections = [len(result) for result in instance.prediction_results]
+
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
+
+    ax1.plot(n_detections, marker="o")
+    ax1.set_title("N° Detections in Time")
+    ax1.set_xlabel("Frame")
+    ax1.set_ylabel("N° Detections")
+
+    ax2.hist(n_detections, bins="auto", edgecolor="black")
+    ax2.set_title("Detection Distribution")
+    ax2.set_xlabel("N° Detections")
+    ax2.set_ylabel("Frequency")
+
+    plt.tight_layout()
+    plt.savefig(output_path / f"{name}.png", dpi=600)
+    plt.close()
+
+
+def main() -> None:
+    video_path = Path("video_to_trajectory/example_video.mp4")
+    n_iterations = 5
+    instance = VisionInstance(
+        source=video_path,
+        output_path=Path("output"),
+        device="0",  # GPU id; use "cpu", or "mps" for macOS users.
+        workers=8,  # number of cores used.
+    )
+    for it in range(n_iterations):
+        instance.predict(prediction_title=f"prediction_{it}")
+        plot_results(
+            instance=instance,
+            output_path=instance.output_path,
+            name=f"results_plot_{it}",
+        )
+        instance.create_dataset_from_predictions(
+            dataset_name=f"dataset_{it}",
+        )
+        instance.set_training_dataset(
+            training_data_yaml=instance.output_path
+            / f"dataset_{it}"
+            / "dataset.yaml",
+        )
+        instance.train(title=f"train_{it}")
+    traj_path = instance.export_prediction_to_xyz(
+        file_name=Path("trajectory.xyz")
+    )
+    track_xyz(
+        input_xyz=traj_path,
+        output_xyz=Path("output/tracked_traj.xyz"),
+        search_range=10,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/video_to_trajectory/example_video.mp4 b/examples/video_to_trajectory/example_video.mp4
new file mode 100644
index 00000000..9c05cb4e
Binary files /dev/null and b/examples/video_to_trajectory/example_video.mp4 differ
diff --git a/pyproject.toml b/pyproject.toml
index 801d3f0d..763cb51f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,17 @@
 maintainers = [
   { name = "Simone Martino", email = "s.martino0898@gmail.com" },
 ]
-dependencies = ["numpy", "dscribe", "tropea-clustering", "MDAnalysis", "deeptime"]
+dependencies = [
+    "numpy",
+    "dscribe",
+    "tropea-clustering",
+    "MDAnalysis",
+    "PyYAML",
+    "trackpy",
+    "ultralytics",
+    "deeptime",
+]
 
+# Set by cpctools.
 requires-python = ">=3.8"
 dynamic = ["version"]
@@ -35,6 +45,9 @@ dev = [
 github = "https://github.com/GMPavanLab/dynsight"
 documentation = "https://dynsight.readthedocs.io/en/latest/"
 
+[project.scripts]
+label_tool = "dynsight.vision:label_tool"
+
 [tool.setuptools_scm]
 
 [tool.ruff]
@@ -112,7 +125,9 @@ module = [
     'ultralytics.*',
     'opencv-python.*',
     'cv2.*',
+    'pandas.*',
+    'trackpy.*',
     'deeptime.*',
-    'sklearn.decomposition.*'
+    'sklearn.decomposition.*',
 ]
 ignore_missing_imports = true
diff --git a/src/dynsight/__init__.py b/src/dynsight/__init__.py
index 5d0f9ee0..f8663e1a 100644
--- a/src/dynsight/__init__.py
+++ b/src/dynsight/__init__.py
@@ -9,6 +9,7 @@
     logs,
     onion,
     soap,
+    track,
     trajectory,
     utilities,
     vision,
@@ -23,6 +24,7 @@
     "logs",
     "onion",
     "soap",
+    "track",
     "trajectory",
     "utilities",
     "vision",
diff --git a/src/dynsight/_internal/track/__init__.py b/src/dynsight/_internal/track/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/dynsight/_internal/track/track.py b/src/dynsight/_internal/track/track.py
new file mode 100644
index 00000000..8927240d
--- /dev/null
+++ b/src/dynsight/_internal/track/track.py
@@ -0,0 +1,186 @@
+"""dynsight.track module for particle tracking from an .xyz file."""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+
+import pandas as pd
+import trackpy as tp
+
+from dynsight.trajectory import Trj
+from dynsight.utilities import read_xyz
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s | %(levelname)s | %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+
+def track_xyz(
+    input_xyz: Path,
+    output_xyz: Path,
+    search_range: float,
+    memory: int = 1,
+    adaptive_stop: None | float = 0.95,
+    adaptive_step: None | float = 0.5,
+) -> Trj:
+    """Track particles from an ``.xyz`` file and write a new file with IDs.
+
+    The input ``.xyz`` is assumed to contain only raw 3D coordinates
+    (without atom labels/identity), and each frame begins with a line
+    indicating the number of objects, followed by a comment line, then a list
+    of positions. Each frame in the input file must follow this structure::
+
+        <number of objects>
+        comment line
+        <x> <y> <z>
+        <x> <y> <z>
+        ...
+
+    or::
+
+        <number of objects>
+        comment line
+        <name> <x> <y> <z>
+        <name> <x> <y> <z>
+        ...
+
+    The output file will have the same structure, but each line will
+    additionally end with the tracked particle ID.
+
+    Parameters:
+        input_xyz:
+            Path to the input .xyz file containing positions only.
+
+        output_xyz:
+            Path where the output .xyz file with particle IDs will be saved.
+
+        search_range:
+            The maximum allowable displacement of objects between frames for
+            them to be considered the same particle. Units depend on the
+            coordinate system used in the input file; if the file comes from
+            vision, the unit is ``pixels``. We do not provide a default
+            value here, because fine-tuning is required to get good
+            behaviour. We recommend starting with a value around 2-3 times
+            the particle diameter (or 10 if you are unsure) and testing on a
+            small trajectory first.
+
+        memory:
+            The maximum number of frames during which an object can vanish,
+            then re-appear nearby, and be considered the same particle.
+
+        adaptive_stop:
+            If not `None`, when encountering a region with too many candidate
+            links (a subnet), retry by progressively reducing `search_range`
+            until the subnet is solvable. If `search_range` becomes less than
+            or equal to `adaptive_stop`, give up and raise a
+            `SubnetOversizeException`.
+
+        adaptive_step:
+            Factor by which the `search_range` is multiplied to reduce it
+            during adaptive search. Effective only if `adaptive_stop` is not
+            `None`.
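+
+    Example:
+        A minimal sketch (``positions.xyz`` is a hypothetical input file)::
+
+            from pathlib import Path
+
+            trj = track_xyz(
+                input_xyz=Path("positions.xyz"),
+                output_xyz=Path("tracked.xyz"),
+                search_range=10,
+            )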
+    """
+    if adaptive_stop is None and adaptive_step is not None:
+        msg = "adaptive_step is set but adaptive_stop is None."
+        raise ValueError(msg)
+    if adaptive_stop is not None and adaptive_step is None:
+        msg = "adaptive_stop is set but adaptive_step is None."
+        raise ValueError(msg)
+
+    input_xyz = Path(input_xyz)
+    output_xyz = Path(output_xyz)
+
+    if not input_xyz.exists():
+        msg = f"Input file not found: {input_xyz}"
+        raise FileNotFoundError(msg)
+
+    positions = read_xyz(
+        input_xyz=input_xyz, cols_order=["name", "x", "y", "z"]
+    )
+
+    if not {"frame", "x", "y", "z"}.issubset(positions.columns):
+        msg = (
+            "Error in the .xyz format. Each line must be "
+            "<x> <y> <z> or <name> <x> <y> <z>."
+        )
+        raise ValueError(msg)
+
+    linked = tp.link_df(
+        positions,
+        search_range=search_range,
+        memory=memory,
+        adaptive_step=adaptive_step,
+        adaptive_stop=adaptive_stop,
+    )
+
+    with output_xyz.open("w") as f:
+        for frame_num in sorted(linked["frame"].unique()):
+            frame_data = linked[linked["frame"] == frame_num].sort_values(
+                "particle"
+            )
+            f.write(f"{len(frame_data)}\n")
+            f.write(f"Frame {frame_num}\n")
+            for _, row in frame_data.iterrows():
+                pid = int(row["particle"])
+                x, y, z = row["x"], row["y"], row["z"]
+                name = row.get("name")
+                if name is not None and pd.notna(name):
+                    f.write(f"{name} {x:.6f} {y:.6f} {z:.6f} {pid}\n")
+                else:
+                    f.write(f"{x:.6f} {y:.6f} {z:.6f} {pid}\n")
+
+    logger.info(f"Linked .xyz file written to: {output_xyz}")
+    return Trj.init_from_xyz(traj_file=output_xyz, dt=1)
+
+
+def _collect_positions(input_xyz: Path) -> pd.DataFrame:
+    """Read the xyz file and return the positions dataset at each frame."""
+    lines = input_xyz.read_text().splitlines()
+
+    data: list[dict[str, object]] = []
+    frame = -1
+    row = 0
+    dimensions = 3
+    for _ in range(len(lines)):
+        if row >= len(lines):
+            break
+        if lines[row].strip().isdigit():
+            num_atoms = int(lines[row])
+            frame += 1
+            row += 2  # skip comment line.
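+            # Parse the atom lines belonging to this frame.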
+            for a in range(num_atoms):
+                if row + a >= len(lines):
+                    break
+                parts = lines[row + a].strip().split()
+                if len(parts) == dimensions:
+                    x, y, z = map(float, parts[0:3])
+                    data.append({"frame": frame, "x": x, "y": y, "z": z})
+                elif len(parts) > dimensions:
+                    name = parts[0]
+                    x, y, z = map(float, parts[1:4])
+                    data.append(
+                        {
+                            "frame": frame,
+                            "name": name,
+                            "x": x,
+                            "y": y,
+                            "z": z,
+                        }
+                    )
+                else:
+                    msg = (
+                        "Invalid line format, expected 3 or 4 columns, "
+                        f"found {len(parts)}"
+                    )
+                    raise ValueError(msg)
+            row += num_atoms
+        else:
+            row += 1
+
+    return pd.DataFrame(data)
diff --git a/src/dynsight/_internal/utilities/utilities.py b/src/dynsight/_internal/utilities/utilities.py
index 8d521bfc..19f3b45f 100644
--- a/src/dynsight/_internal/utilities/utilities.py
+++ b/src/dynsight/_internal/utilities/utilities.py
@@ -1,12 +1,11 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Literal
-
-if TYPE_CHECKING:
-    from pathlib import Path
+from pathlib import Path
+from typing import Callable, Literal, Mapping, Sequence
 
 import numpy as np
 import numpy.typing as npt
+import pandas as pd
 from scipy.signal import find_peaks
 
 from dynsight.trajectory import Insight, Trj
@@ -102,3 +101,75 @@ def load_or_compute_soap(
         soap.dump_to_json(soap_path)
 
     return soap
+
+
+Col = Literal["name", "x", "y", "z", "ID"]
+
+
+def _entry_from_parts(
+    parts: Sequence[str],
+    cols_order: Sequence[Col],
+    frame: int,
+) -> dict[str, object]:
+    converters: Mapping[Col, Callable[[str], object]] = {
+        "name": str,
+        "x": float,
+        "y": float,
+        "z": float,
+        "ID": int,
+    }
+    entry: dict[str, object] = {"frame": frame}
+    for c, col in enumerate(cols_order):
+        entry[col] = converters[col](parts[c])
+    return entry
+
+
+def read_xyz(
+    input_xyz: Path | str,
+    cols_order: Sequence[Col],
+) -> pd.DataFrame:
+    """Read an XYZ trajectory file into a pandas DataFrame.
+
+    The function parses a file in extended XYZ format where each frame begins
+    with a line containing the number of atoms, followed by a comment/title
+    line, and then one line per atom containing at least the columns
+    specified in `cols_order`, in that order.
+
+    Parameters:
+        input_xyz:
+            Path to the XYZ file to read.
+        cols_order:
+            The expected column order for each atom line (e.g.,
+            ["name", "x", "y", "z", "ID"]).
+
+    Returns:
+        A DataFrame containing all parsed atomic entries. Each row corresponds
+        to one atom in one frame, with columns given by `cols_order` plus a
+        `frame` column holding the frame index.
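+
+    Example:
+        A minimal sketch (``traj.xyz`` is a hypothetical four-column file)::
+
+            df = read_xyz("traj.xyz", cols_order=["name", "x", "y", "z"])
+            frame0 = df[df["frame"] == 0]  # all atoms in the first frame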
+ """ + lines = Path(input_xyz).read_text().splitlines() + data: list[dict[str, object]] = [] + + frame = -1 + row = 0 + nlines = len(lines) + + while row < nlines: + token = lines[row].strip() + if token.isdigit(): + n_atoms = int(token) + frame += 1 + row += 2 # skip comment/title line + + end = min(row + n_atoms, nlines) + for i in range(row, end): + parts = lines[i].split() + if len(parts) < len(cols_order): + continue + data.append(_entry_from_parts(parts, cols_order, frame)) + + row += n_atoms + else: + row += 1 + + return pd.DataFrame(data) diff --git a/src/dynsight/_internal/vision/detect.py b/src/dynsight/_internal/vision/detect.py deleted file mode 100644 index 3981486e..00000000 --- a/src/dynsight/_internal/vision/detect.py +++ /dev/null @@ -1,847 +0,0 @@ -from __future__ import annotations - -import logging -import pathlib -import shutil -import tkinter as tk -from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast - -import numpy as np -from PIL import Image - -try: - from ultralytics import YOLO -except ImportError: - YOLO = None - -try: - import yaml -except ImportError: - yaml = None - -from .vision_gui import VisionGUI -from .vision_utilities import find_outliers - -if TYPE_CHECKING: - from .video_to_frame import Video - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s | %(levelname)s | %(message)s", -) -logger = logging.getLogger(__name__) - - -class YAMLConfig(TypedDict): - """Class for YAML file configurations.""" - - train: list[str] - val: list[str] - nc: int - names: list[str] - - -class Detect: - """A class to manage the full pipeline of object detection from videos. - - * Author: Simone Martino - - .. caution:: - This part of the code is still under development and may - contain errors. - - """ - - def __init__( - self, - input_video: Video, - project_folder: pathlib.Path = Path("output_folder"), - ) -> None: - """Initialize a detection project. - - Creates all necessary subdirectories and extracts video frames if not - already present. - - Parameters: - input_video: - The input video object from which to extract frames. - - project_folder: - Path to the main project directory. - """ - # Main directory for all the project outputs. - self._project_folder = project_folder - self._project_folder.mkdir(parents=True, exist_ok=True) - - # Directory where store extracted video frames. - self._frame_path = self._project_folder / "frames" - - # Directory path where save the training item crops selected. - self._training_items_path = self._project_folder / "training_items" - - # Directory path for the generated synthetic dataset. - self._syn_dataset_path = self._project_folder / "synthetic_dataset" - - # Directory path where save trained models. - self._models_path = self._project_folder / "models" - - # Directory path where save prediction outputs. - self._predictions_path = self._project_folder / "predictions" - - # Path to the YAML configuration file for dataset training. - self._yaml_file_path = self._project_folder / "training_options.yaml" - - # Resolution of the input video (width, height) in pixels. - self._video_size = input_video.resolution() - # Total number of frames extracted from the video. 
- self._n_frames = input_video.count_frames() - - # Check if the video's frame are already present - # if not -> extract them - if not self._frame_path.exists(): - input_video.extract_frames(project_folder) - - def get_project_path(self) -> pathlib.Path: - """It returns the path of the detection project.""" - return self._project_folder - - def synthesize( - self, - dataset_dimension: int = 1000, - reference_img_path: None | pathlib.Path = None, - validation_set_fraction: float = 0.2, - collage_size: None | tuple[int, int] = None, - collage_max_repeats: int = 30, - sample_from: Literal["gui"] = "gui", - random_seed: int | None = None, - ) -> None: - """Generate a synthetic dataset by creating collages of training items. - - Parameters: - dataset_dimension: - Total number of synthetic samples to generate. - - reference_img_path: - Path to the reference image to initialize the GUI. - If None, uses the first extracted frame. - - validation_set_fraction: - Fraction of samples to allocate to the validation set. - - collage_size: - Size (width, height) of the generated collage images. - If None the input_video size will be taken. - - collage_max_repeats: - Maximum number of training items that can be placed in - a single collage. - - sample_from: - Mode to collect training items. in the current version - only the "gui" mode is available, other modes will come - in the future. - - random_seed: - Seed for shuffling the dataset. - """ - if collage_size is None: - collage_size = self._video_size - # Dataset structure - images_train_dir = self._syn_dataset_path / "images" / "train" - images_val_dir = self._syn_dataset_path / "images" / "val" - labels_train_dir = self._syn_dataset_path / "labels" / "train" - labels_val_dir = self._syn_dataset_path / "labels" / "val" - - # Sampling method - self._sample( - sample_from=sample_from, - reference_img_path=reference_img_path, - ) - - logger.info("Initializing the synthetic dataset") - # Initialize a new dataset - self._syn_dataset_path.mkdir(exist_ok=True) - images_train_dir.mkdir(exist_ok=True, parents=True) - images_val_dir.mkdir(exist_ok=True, parents=True) - labels_train_dir.mkdir(exist_ok=True, parents=True) - labels_val_dir.mkdir(exist_ok=True, parents=True) - - # Split between training and validation set - num_val = int(dataset_dimension * validation_set_fraction) - num_train = dataset_dimension - num_val - remaining = dataset_dimension - (num_train + num_val) - num_train += remaining - - assignments = ["train"] * num_train + ["val"] * num_val - - rng = np.random.default_rng(seed=random_seed) - rng.shuffle(assignments) - - # Create synthetic images to fill the dataset - # Create labels for each images generated - logger.info("Generating synthetic images") - for i in range(1, dataset_dimension + 1): - collage, label_lines = self._create_collage( - images_folder=self._training_items_path, - width=collage_size[0], - height=collage_size[1], - max_repeats=collage_max_repeats, - random_seed=random_seed, - ) - subset = assignments[i - 1] - if subset == "train": - image_save_path = images_train_dir / f"{i}.png" - label_save_path = labels_train_dir / f"{i}.txt" - else: - image_save_path = images_val_dir / f"{i}.png" - label_save_path = labels_val_dir / f"{i}.txt" - - collage.save(image_save_path) - with label_save_path.open("w") as f: - for line in label_lines: - f.write(line + "\n") - - logger.info("Generating yaml configuration file") - # Generate the config file for the dataset created - self._add_or_create_yaml(self._syn_dataset_path) - - # Just a bridge 
to the YOLO library - def train( - self, - yaml_file: pathlib.Path, - batch_size: int, - workers: int, - initial_model: str | pathlib.Path = "yolo12x.pt", - training_name: str | None = None, - training_epochs: int = 100, - training_patience: int = 100, - device: int | str | list[int] | None = None, - ) -> None: - """Train a YOLO model on the selected dataset. - - This function uses the - `ultralytics YOLO library `_ - for the model training. - - Parameters: - yaml_file: - Path to the dataset YAML configuration file. - - initial_model: - Initial pretrained model to fine-tune. - - training_name: - Name for the training run. - - training_epochs: - Maximum number of training epochs for each training session. - - training_patience: - Early stopping patience (number of epochs without improvement). - - batch_size: - Batch size for training. - - workers: - Number of dataloader worker threads. - - device: - Device(s) on which run training. - """ - model = YOLO(initial_model) - model.train( - data=yaml_file, - epochs=training_epochs, - patience=training_patience, - batch=batch_size, - imgsz=self._video_size, - workers=workers, - project=self._models_path, - name=training_name, - device=device, - plots=False, - ) - - # Just a bridge to the YOLO library - def predict_frames( - self, - model_path: str | pathlib.Path, - detections_iou: float = 0.1, - prediction_name: str = "prediction", - ) -> None: - """Perform object detection predictions on the extracted frames. - - This function uses the - `ultralytics YOLO library `_ - to detect objects in videos. - - Parameters: - model_path: - Path to the trained model. - - detections_iou: - IOU threshold for object detection filtering. - - prediction_name: - Name under which save the prediction results. - """ - model = YOLO(model_path) - for frame in range(self._n_frames): - model.predict( - project=self._project_folder, - source=self._frame_path / f"{frame}.png", - name=prediction_name, - augment=True, - line_width=2, - save=True, - show_labels=False, - save_txt=True, - save_conf=True, - iou=detections_iou, - max_det=20000, - exist_ok=True, - imgsz=self._video_size, - ) - - def fit( - self, - initial_dataset: pathlib.Path, - max_sessions: int, - training_epochs: int, - training_patience: int, - batch_size: int, - workers: int, - initial_model: str | pathlib.Path = "yolo12x.pt", - device: int | str | list[int] | None = None, - frame_reading_step: int = 1, - ) -> None: - """Train an object detection model through iterative self-training. - - This method performs multiple rounds of training and prediction: - 1. Train the model on the initial dataset. - 2. Predict bounding boxes on video frames using the trained model. - 3. Identify and remove outlier detections based on box sizes. - 4. Build a new training dataset from filtered detections. - 5. Retrain the model on the refined dataset. - 6. Repeat steps 2 to 5 for a given number of sessions to - progressively refine the model. - - This method uses the - `ultralytics YOLO library `_. - - Parameters: - initial_dataset: - Path to the initial dataset YAML file. - - initial_model: - Path to the initial model weights (.pt file). - - max_sessions: - Number of retraining cycles. - - training_epochs: - Number of epochs per training session. - - training_patience: - Early stopping patience during training. - - batch_size: - Batch size for training. - - workers: - Number of data loader workers. - - device: - Device(s) to use (e.g., "cpu", "0", [0,1]). 
- - frame_reading_step: - Specifies the interval at which frames are sampled from - the video during processing. - """ - # Initilize the first training - current_dataset = initial_dataset - guess_model_name = "v0" - prediction_number = 0 - detection_results = [] - - logger.info("First training begins.") - self.train( - yaml_file=current_dataset, - initial_model=initial_model, - training_epochs=2, - batch_size=batch_size, - workers=workers, - device=device, - training_name=guess_model_name, - ) - - current_model_path = ( - self._project_folder - / "models" - / guess_model_name - / "weights" - / "best.pt" - ) - current_model = YOLO(current_model_path) - logger.info(f"Starting prediction number {prediction_number}") - for f in range(0, self._n_frames, frame_reading_step): - frame_file = self._frame_path / f"{f}.png" - prediction = current_model.predict( - source=frame_file, - imgsz=self._video_size, - augment=True, - save=True, - save_txt=True, - save_conf=True, - show_labels=False, - name=f"attempt_{prediction_number}", - iou=0.1, - max_det=20000, - project=self._predictions_path, - line_width=2, - exist_ok=True, - ) - # Read and save the prediction results - if prediction and prediction[0].boxes: - xywh = prediction[0].boxes.xywh.cpu().numpy() - conf = prediction[0].boxes.conf.cpu().numpy() - cls = prediction[0].boxes.cls.cpu().numpy() - n_detection = len(xywh) - - for i in range(n_detection): - x, y, w, h = xywh[i] - detection_results.append( - { - "frame": f, - "class_id": int(cls[i]), - "center_x": float(x), - "center_y": float(y), - "width": float(w), - "height": float(h), - "confidence": float(conf[i]), - } - ) - logger.info("Looking for outliers") - # Filter detections - detection_results = self._filter_detections( - detection_results, - prediction_number, - ) - # Build a new dataset based on the "filtered" detection results - # New dataset path - train_dataset_path = ( - self._project_folder - / "train_datasets" - / f"dataset_{prediction_number}" - ) - logger.info(f"Building the dataset (version {prediction_number})") - # Build the dataset - self._build_dataset( - detection_results=detection_results, - dataset_name=f"dataset_{prediction_number}", - ) - # Add the new dataset to the training config file - self._add_or_create_yaml(train_dataset_path) - - # Iterative part to improve the model performace - for s in range(max_sessions): - logger.info(f"Starting a new training session (number {s + 1}") - new_model_name = f"v{s + 1}" - self.train( - yaml_file=self._yaml_file_path, - initial_model=current_model_path, - training_epochs=training_epochs, - training_patience=training_patience, - batch_size=batch_size, - workers=workers, - device=device, - training_name=new_model_name, - ) - current_model_path = ( - self._project_folder - / "models" - / new_model_name - / "weights" - / "best.pt" - ) - current_model = YOLO(current_model_path) - prediction_number += 1 - detection_results = [] - logger.info(f"Starting prediction number {prediction_number}") - for f in range(0, self._n_frames, frame_reading_step): - frame_file = self._frame_path / f"{f}.png" - prediction = current_model.predict( - source=frame_file, - imgsz=self._video_size, - augment=True, - save=True, - save_txt=True, - save_conf=True, - show_labels=False, - name=f"attempt_{prediction_number}", - iou=0.1, - max_det=20000, - project=self._project_folder / "predictions", - line_width=2, - exist_ok=True, - ) - # Read prediction - if prediction and prediction[0].boxes: - xywh = prediction[0].boxes.xywh.cpu().numpy() - conf = 
prediction[0].boxes.conf.cpu().numpy() - cls = prediction[0].boxes.cls.cpu().numpy() - n_detection = len(xywh) - - for i in range(n_detection): - x, y, w, h = xywh[i] - detection_results.append( - { - "frame": f, - "class_id": int(cls[i]), - "center_x": float(x), - "center_y": float(y), - "width": float(w), - "height": float(h), - "confidence": float(conf[i]), - } - ) - logger.info("Looking for outliers") - # Filter detections - detection_results = self._filter_detections( - detection_results, - prediction_number, - ) - logger.info(f"Building the dataset (version {prediction_number}") - # Build the new dataset - self._build_dataset( - detection_results=detection_results, - dataset_name=f"dataset_{prediction_number}", - ) - # Remove the oldest dataset in config - # It has been made to avoid training bias on worst results - self._remove_old_dataset() - - train_dataset_path = ( - self._project_folder - / "train_datasets" - / f"dataset_{prediction_number}" - ) - # Update the dataset config file - self._add_or_create_yaml(train_dataset_path) - - def compute_xyz( - self, - prediction_folder_path: pathlib.Path, - output_path: pathlib.Path, - ) -> None: - """Computes and saves the trajectory of detections to xyz file. - - Parameters: - prediction_folder_path: - The path to the folder containing detection data files. - output_path: - The path where the resulting trajectory data should be saved. - """ - lab_folder = prediction_folder_path / "labels" - frame_positions = [] - for frame in range(self._n_frames): - label_file = lab_folder / f"{frame}.txt" - - with label_file.open("r") as file: - frame_detections = [] - for line in file: - values = line.strip().split() - _, x, y, width, height, confidence = map(float, values) - x *= self._video_size[0] - y *= self._video_size[1] - frame_detections.append((x, y)) - frame_positions.append(frame_detections) - - with output_path.open("w") as file: - for frame_index, detections in enumerate(frame_positions): - file.write(f"{len(detections)}\n") - file.write(f"Frame {frame_index}\n") - for x, y in detections: - z = 0 - file.write(f"{x:.6f} {y:.6f} {z:.6f}\n") - - def _filter_detections( - self, - input_results: list[dict[str, int | float]], - prediction_number: int, - ) -> list[dict[str, int | float]]: - # Initialize outliers folder in the prediction folder - outliers_plt_folder = ( - self._project_folder - / "predictions" - / f"attempt_{prediction_number}" - / "outliers" - ) - outliers_plt_folder.mkdir(exist_ok=True) - - # Look for outliers in the boxes width and height - widths = np.array([d["width"] for d in input_results], dtype=float) - heights = np.array( - [d["height"] for d in input_results], - dtype=float, - ) - try: - out_width = set( - find_outliers( - distribution=widths, - save_path=outliers_plt_folder, - fig_name="width", - ) - ) - except (RuntimeError, ValueError) as e: - logger.warning( - "Outlier detection for width failed: %s. No width outliers", - e, - ) - out_width = set() - - try: - out_height = set( - find_outliers( - distribution=heights, - save_path=outliers_plt_folder, - fig_name="height", - ) - ) - except (RuntimeError, ValueError) as e: - logger.warning( - "Outlier detection for height failed: %s. 
No height outliers.", - e, - ) - out_height = set() - - filtered_results = [ - det - for det in input_results - if (det["width"] not in out_width) - and (det["height"] not in out_height) - ] - - if not filtered_results: - logger.warning("No outliers detected.") - filtered_results = input_results - - return filtered_results - - def _remove_old_dataset(self) -> None: - """Removes the oldest dataset from the YAML configuration.""" - yaml_path = self._yaml_file_path - - if not yaml_path.exists(): - return - - with yaml_path.open("r") as f: - cfg = yaml.safe_load(f) or {} - - for key in ("train", "val"): - if key not in cfg: - return - if isinstance(cfg[key], str): - cfg[key] = [cfg[key]] - elif not isinstance(cfg[key], list): - return - - for key in ("train", "val"): - if cfg[key]: - cfg[key].pop(0) - - with yaml_path.open("w") as f: - yaml.safe_dump(cfg, f, sort_keys=False) - - def _add_or_create_yaml(self, new_dataset_path: Path) -> None: - yaml_path = Path(self._yaml_file_path) - - train_p = str((new_dataset_path / "images/train").resolve()) - val_p = str((new_dataset_path / "images/val").resolve()) - - if not yaml_path.exists(): - cfg: YAMLConfig = { - "train": [train_p], - "val": [val_p], - "nc": 1, - "names": ["obj"], - } - else: - raw = yaml.safe_load(yaml_path.open("r")) or {} - cfg = cast("YAMLConfig", raw) - - if not isinstance(cfg.get("train"), list): - cfg["train"] = [str(cfg.get("train") or train_p)] - if not isinstance(cfg.get("val"), list): - cfg["val"] = [str(cfg.get("val") or val_p)] - - cfg["train"].append(train_p) - cfg["val"].append(val_p) - cfg["train"] = list(dict.fromkeys(cfg["train"])) - cfg["val"] = list(dict.fromkeys(cfg["val"])) - # cfg.pop("path", None) # noqa: ERA001 - - if isinstance(cfg.get("nc"), str): - cfg["nc"] = int(cfg["nc"]) - - with yaml_path.open("w") as f: - yaml.safe_dump(cfg, f, sort_keys=False) - - def _build_dataset( - self, - detection_results: list[dict[str, Any]], - dataset_name: str, - split_ratio: float = 0.8, - ) -> None: - """Builds a dataset by splitting frames/labels into train and val.""" - output_dir = self._project_folder / "train_datasets" / dataset_name - imgs_train_dir = output_dir / "images" / "train" - imgs_val_dir = output_dir / "images" / "val" - labs_train_dir = output_dir / "labels" / "train" - labs_val_dir = output_dir / "labels" / "val" - for d in (imgs_train_dir, imgs_val_dir, labs_train_dir, labs_val_dir): - d.mkdir(parents=True, exist_ok=True) - - detections_by_frame: dict[int, list[dict[str, Any]]] = {} - for det in detection_results: - frame_idx = det["frame"] - detections_by_frame.setdefault(frame_idx, []).append(det) - - all_frames = sorted(detections_by_frame.keys()) - - split_point = int(len(all_frames) * split_ratio) - train_frames = set(all_frames[:split_point]) - - for frame_idx, dets in detections_by_frame.items(): - if frame_idx in train_frames: - img_dest = imgs_train_dir - lab_dest = labs_train_dir - else: - img_dest = imgs_val_dir - lab_dest = labs_val_dir - - img_src = self._frame_path / f"{frame_idx}.png" - img_dst = img_dest / f"{frame_idx}.png" - shutil.copy(img_src, img_dst) - - lab_file = lab_dest / f"{frame_idx}.txt" - with lab_file.open("w") as f: - img_w, img_h = Image.open(img_src).size - - for det in dets: - x_ctr = det["center_x"] - y_ctr = det["center_y"] - w = det["width"] - h = det["height"] - cls = det["class_id"] - x_ctr_n = x_ctr / img_w - y_ctr_n = y_ctr / img_h - w_n = w / img_w - h_n = h / img_h - f.write( - f"{cls} {x_ctr_n:.6f} {y_ctr_n:.6f} " - f"{w_n:.6f} {h_n:.6f}\n" - ) - - def 
_create_collage( - self, - images_folder: pathlib.Path, - width: int, - height: int, - random_seed: int | None = None, - patience: int = 1000, - max_repeats: int = 1, - ) -> tuple[Image.Image, list[str]]: - """Creates a collage of images by placing them randomly on a canvas.""" - collage = Image.new("RGBA", (width, height), (255, 255, 255, 255)) - placed_rects: list[tuple[int, int, int, int]] = [] - label_lines = [] - cropped_images = [] - - for file in images_folder.iterdir(): - if file.suffix.lower() in {".png", ".jpg", ".jpeg"}: - img = Image.open(file).convert("RGBA") - cropped_images.append(img) - - if not cropped_images: - msg = "No images found in the specified folder" - raise ValueError(msg) - - total_placement = len(cropped_images) * max_repeats - placed_count = 0 - cropped_images_array = np.array(cropped_images, dtype=object) - rng = np.random.default_rng(seed=random_seed) - for _ in range(total_placement): - cropped = rng.choice(cropped_images_array) - w, h = cropped.size[:2] - max_x = width - w - max_y = height - h - placed = False - - for _ in range(patience): - x = rng.integers(0, max_x + 1) - y = rng.integers(0, max_y + 1) - new_rect = (x, y, x + w, y + h) - - overlap = any( - not ( - new_rect[2] <= rect[0] - or new_rect[0] >= rect[2] - or new_rect[3] <= rect[1] - or new_rect[1] >= rect[3] - ) - for rect in placed_rects - ) - - if not overlap: - collage.paste(cropped, (x, y), cropped) - placed_rects.append(new_rect) - - center_x = (x + w / 2) / width - center_y = (y + h / 2) / height - width_norm = w / width - height_norm = h / height - label_line = ( - f"0 {center_x:.6f} " - f"{center_y:.6f} " - f"{width_norm:.6f} " - f"{height_norm:.6f}" - ) - label_lines.append(label_line) - placed_count += 1 - placed = True - break - - if not placed: - break - - return collage, label_lines - - def _sample( - self, sample_from: str, reference_img_path: pathlib.Path | None - ) -> None: - # GUI mode - if sample_from == "gui": - logger.info("Loading Graphic User Interface") - # If not specified by the user use the first img as example - if reference_img_path is None: - logger.info("Using first frame as reference img.") - reference_img_path = self._frame_path / "0.png" - # Open the GUI - root = tk.Tk() - VisionGUI( - master=root, - image_path=reference_img_path, - destination_folder=self._project_folder, - ) - root.mainloop() - # Check if the training items has been properly created - if not ( - self._training_items_path.exists() - and self._training_items_path.is_dir() - ): - msg = "'training_items' folder not created or not found" - raise ValueError(msg) - else: - raise NotImplementedError diff --git a/src/dynsight/_internal/vision/label_tool.py b/src/dynsight/_internal/vision/label_tool.py new file mode 100644 index 00000000..f3f2d62f --- /dev/null +++ b/src/dynsight/_internal/vision/label_tool.py @@ -0,0 +1,40 @@ +import functools +import logging +import threading +import webbrowser +from http.server import SimpleHTTPRequestHandler +from pathlib import Path +from socketserver import TCPServer + +logger = logging.getLogger(__name__) + + +class HTTPRequestHandler(SimpleHTTPRequestHandler): + def log_message(self, fmt: str, *args: object) -> None: + pass + + # do_POST must be uppercase + def do_POST(self) -> None: + if self.path == "/shutdown": + self.send_response(200) + self.end_headers() + logger.info("Shutdown request received.") + threading.Thread(target=self.server.shutdown).start() + else: + self.send_error(404) + + +class ReusableTCPServer(TCPServer): + allow_reuse_address = True + + 
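+# ReusableTCPServer sets allow_reuse_address so that a relaunched label_tool
+# can rebind the same port without waiting for the old socket's TIME_WAIT.
+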
+def label_tool(port: int = 8888) -> None:
+    web_dir = Path(__file__).parent / "label_tool"
+    handler = functools.partial(HTTPRequestHandler, directory=str(web_dir))
+    with ReusableTCPServer(("", port), handler) as httpd:
+        url = f"http://localhost:{port}/index.html"
+        logger.info(f"Starting server at {url}")
+        webbrowser.open(url)
+        httpd.serve_forever()
+        httpd.server_close()
+    logger.info("Server closed.")
diff --git a/src/dynsight/_internal/vision/label_tool/index.html b/src/dynsight/_internal/vision/label_tool/index.html
new file mode 100644
index 00000000..d7e05bfd
--- /dev/null
+++ b/src/dynsight/_internal/vision/label_tool/index.html
@@ -0,0 +1,57 @@
+<!DOCTYPE html>
+<html lang="en">
+  <head>
+    <meta charset="UTF-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <title>dynsight labeling tool</title>
+    <link rel="stylesheet" href="styles.css" />
+    <script src="https://cdnjs.cloudflare.com/ajax/libs/jszip/3.10.1/jszip.min.js"></script>
+  </head>
+  <body>
+    <div class="sidebar">
+      <h3>Labels</h3>
+      <input type="text" id="newLabelInput" placeholder="New label name" />
+      <button id="addLabelBtn" class="upload-btn">Add label</button>
+      <div id="labelList"></div>
+    </div>
+    <div class="main">
+      <div class="top-bar">
+        <input type="file" id="imageInput" accept="image/*" multiple />
+        <button id="prevImage" class="nav-btn">Previous</button>
+        <button id="nextImage" class="nav-btn">Next</button>
+        <button id="clearLast" class="nav-btn">Clear last</button>
+        <button id="clearAll" class="nav-btn">Clear all</button>
+        <button id="exportYolo" class="nav-btn">Export label</button>
+        <button id="exportAll" class="nav-btn">Export dataset</button>
+        <button id="synthesize" class="nav-btn">Synthesize dataset</button>
+        <input type="range" id="zoomSlider" min="10" max="400" value="100" />
+      </div>
+      <div class="image-container" id="imageContainer">
+        <div id="imageWrapper">
+          <img id="imageDisplay" alt="" />
+          <div id="overlay"></div>
+        </div>
+        <div class="crosshair-line" id="verticalLine"></div>
+        <div class="crosshair-line" id="horizontalLine"></div>
+      </div>
+    </div>
+    <script src="script.js"></script>
+  </body>
+</html>
+ + + + + + diff --git a/src/dynsight/_internal/vision/label_tool/script.js b/src/dynsight/_internal/vision/label_tool/script.js new file mode 100644 index 00000000..79b6f9de --- /dev/null +++ b/src/dynsight/_internal/vision/label_tool/script.js @@ -0,0 +1,570 @@ +// script.js + +const imageInput = document.getElementById("imageInput"); +const imageDisplay = document.getElementById("imageDisplay"); +const imageContainer = document.getElementById("imageContainer"); +const imageWrapper = document.getElementById("imageWrapper"); +const labelList = document.getElementById("labelList"); +const addLabelBtn = document.getElementById("addLabelBtn"); +const newLabelInput = document.getElementById("newLabelInput"); +const clearLastBtn = document.getElementById("clearLast"); +const clearAllBtn = document.getElementById("clearAll"); +const exportBtn = document.getElementById("exportYolo"); +const exportAllBtn = document.getElementById("exportAll"); +const synthBtn = document.getElementById("synthesize"); +const nextImageBtn = document.getElementById("nextImage"); +const prevImageBtn = document.getElementById("prevImage"); +const verticalLine = document.getElementById("verticalLine"); +const horizontalLine = document.getElementById("horizontalLine"); +const zoomSlider = document.getElementById("zoomSlider"); + +let zoomLevel = 1; +let baseZoom = 1; +let naturalWidth = 0; +let naturalHeight = 0; + +const overlay = document.getElementById("overlay"); + +verticalLine.style.display = "none"; +horizontalLine.style.display = "none"; + +imageContainer.onmouseenter = () => { + verticalLine.style.display = "block"; + horizontalLine.style.display = "block"; +}; + +imageContainer.onmouseleave = () => { + verticalLine.style.display = "none"; + horizontalLine.style.display = "none"; +}; + +let currentLabel = null; +const labelColors = {}; +let isDrawing = false; +let startX, + startY, + box = null; + +let images = []; +let currentIndex = 0; +const annotations = {}; // imageName -> [boxData] + +function getRandomColor() { + const hue = Math.floor(Math.random() * 360); + return `hsl(${hue}, 90%, 50%)`; +} + +function setActiveLabel(item) { + document + .querySelectorAll(".label-item") + .forEach((i) => i.classList.remove("active")); + item.classList.add("active"); + currentLabel = item.textContent; +} + +function createLabelItem(text) { + const item = document.createElement("div"); + item.className = "label-item"; + item.textContent = text; + labelColors[text] = labelColors[text] || getRandomColor(); + item.style.backgroundColor = labelColors[text]; + item.style.color = "#fff"; + item.addEventListener("click", () => setActiveLabel(item)); + labelList.appendChild(item); +} + +addLabelBtn.onclick = () => { + const label = newLabelInput.value.trim(); + if (label && !labelColors[label]) { + createLabelItem(label); + newLabelInput.value = ""; + } +}; + +imageInput.onchange = (e) => { + images = Array.from(e.target.files); + currentIndex = 0; + loadImage(currentIndex); +}; + +function loadImage(index) { + if (!images[index]) return; + const url = URL.createObjectURL(images[index]); + imageDisplay.onload = () => { + const iw = imageDisplay.naturalWidth; + const ih = imageDisplay.naturalHeight; + imageDisplay.style.width = `${iw}px`; + imageDisplay.style.height = `${ih}px`; + naturalWidth = iw; + naturalHeight = ih; + baseZoom = Math.min( + imageContainer.clientWidth / iw, + imageContainer.clientHeight / ih, + 1, + ); + zoomLevel = baseZoom; + zoomSlider.value = 100; + updateTransform(); + const name = images[index].name; + 
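+  // Restore the boxes previously drawn on this image.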
if (!annotations[name]) annotations[name] = []; + clearBoxes(); + annotations[name].forEach(addBoxFromData); + }; + imageDisplay.src = url; +} + +zoomSlider.oninput = (e) => { + zoomLevel = (e.target.value / 100) * baseZoom; + updateTransform(); + clearBoxes(); + annotations[images[currentIndex].name].forEach(addBoxFromData); +}; + +function clearBoxes() { + overlay.innerHTML = ""; +} + +function updateTransform() { + const w = naturalWidth * zoomLevel; + const h = naturalHeight * zoomLevel; + imageDisplay.style.width = `${w}px`; + imageDisplay.style.height = `${h}px`; + imageWrapper.style.width = `${w}px`; + imageWrapper.style.height = `${h}px`; + overlay.style.width = `${w}px`; + overlay.style.height = `${h}px`; +} + +function addBoxFromData(data) { + const box = document.createElement("div"); + box.className = "bounding-box"; + box.style.left = `${data.left * zoomLevel}px`; + box.style.top = `${data.top * zoomLevel}px`; + box.style.width = `${data.width * zoomLevel}px`; + box.style.height = `${data.height * zoomLevel}px`; + box.style.border = `2px dashed ${labelColors[data.label]}`; + box.style.backgroundColor = labelColors[data.label] + .replace("hsl", "hsla") + .replace(")", ", 0.1)"); + + const tag = document.createElement("div"); + tag.className = "label-tag"; + tag.textContent = data.label; + tag.style.backgroundColor = labelColors[data.label]; + box.appendChild(tag); + + overlay.appendChild(box); +} + +imageContainer.onmousedown = (e) => { + if (e.button !== 0) { + return; + } + + if (!currentLabel || !images[currentIndex]) return; + + const rect = imageDisplay.getBoundingClientRect(); + startX = (e.clientX - rect.left) / zoomLevel; + startY = (e.clientY - rect.top) / zoomLevel; + + // Ignore clicks started outside the image boundaries + if ( + startX < 0 || + startY < 0 || + startX > naturalWidth || + startY > naturalHeight + ) { + isDrawing = false; + return; + } + + box = document.createElement("div"); + box.className = "bounding-box"; + box.style.left = `${startX * zoomLevel}px`; + box.style.top = `${startY * zoomLevel}px`; + box.style.border = `2px dashed ${labelColors[currentLabel]}`; + box.style.backgroundColor = labelColors[currentLabel] + .replace("hsl", "hsla") + .replace(")", ", 0.1)"); + + const tag = document.createElement("div"); + tag.className = "label-tag"; + tag.textContent = currentLabel; + tag.style.backgroundColor = labelColors[currentLabel]; + box.appendChild(tag); + + overlay.appendChild(box); + isDrawing = true; +}; + +imageContainer.onmousemove = (e) => { + const imgRect = imageDisplay.getBoundingClientRect(); + const containerRect = imageContainer.getBoundingClientRect(); + + const currX = (e.clientX - imgRect.left) / zoomLevel; + const currY = (e.clientY - imgRect.top) / zoomLevel; + const clampedX = Math.max(0, Math.min(naturalWidth, currX)); + const clampedY = Math.max(0, Math.min(naturalHeight, currY)); + + verticalLine.style.left = `${ + e.clientX - containerRect.left + imageContainer.scrollLeft + }px`; + horizontalLine.style.top = `${ + e.clientY - containerRect.top + imageContainer.scrollTop + }px`; + + if (!isDrawing || !box) return; + + box.style.left = `${Math.min(clampedX, startX) * zoomLevel}px`; + box.style.top = `${Math.min(clampedY, startY) * zoomLevel}px`; + box.style.width = `${Math.abs(clampedX - startX) * zoomLevel}px`; + box.style.height = `${Math.abs(clampedY - startY) * zoomLevel}px`; +}; + +imageContainer.onmouseup = (e) => { + if (!isDrawing || !box) return; + + const imgRect = imageDisplay.getBoundingClientRect(); + const endX = 
(e.clientX - imgRect.left) / zoomLevel; + const endY = (e.clientY - imgRect.top) / zoomLevel; + const clampedX = Math.max(0, Math.min(naturalWidth, endX)); + const clampedY = Math.max(0, Math.min(naturalHeight, endY)); + + const left = Math.min(startX, clampedX); + const top = Math.min(startY, clampedY); + const width = Math.abs(clampedX - startX); + const height = Math.abs(clampedY - startY); + + annotations[images[currentIndex].name].push({ + label: currentLabel, + left, + top, + width, + height, + }); + + box = null; + isDrawing = false; +}; + +clearLastBtn.onclick = () => { + const ann = annotations[images[currentIndex].name]; + if (ann.length > 0) { + ann.pop(); + clearBoxes(); + ann.forEach(addBoxFromData); + } +}; + +clearAllBtn.onclick = () => { + annotations[images[currentIndex].name] = []; + clearBoxes(); +}; + +prevImageBtn.onclick = () => { + if (currentIndex > 0) { + currentIndex--; + loadImage(currentIndex); + } +}; + +nextImageBtn.onclick = () => { + if (currentIndex < images.length - 1) { + currentIndex++; + loadImage(currentIndex); + } +}; + +exportBtn.onclick = () => { + const img = images[currentIndex]; + if (!img) return; + const iw = imageDisplay.naturalWidth; + const ih = imageDisplay.naturalHeight; + const annots = annotations[img.name] || []; + const labelMap = {}; + let nextId = 0; + let txt = ""; + annots.forEach((obj) => { + if (!(obj.label in labelMap)) labelMap[obj.label] = nextId++; + const cx = (obj.left + obj.width / 2) / iw; + const cy = (obj.top + obj.height / 2) / ih; + const w = obj.width / iw; + const h = obj.height / ih; + txt += + labelMap[obj.label] + + " " + + cx.toFixed(6) + + " " + + cy.toFixed(6) + + " " + + w.toFixed(6) + + " " + + h.toFixed(6) + + "\n"; + }); + const blob = new Blob([txt], { type: "text/plain" }); + const a = document.createElement("a"); + a.href = URL.createObjectURL(blob); + a.download = img.name.replace(/\.[^/.]+$/, "") + ".txt"; + a.click(); + URL.revokeObjectURL(a.href); +}; + +exportAllBtn.onclick = async () => { + if (images.length === 0) { + alert("No images uploaded."); + return; + } + let trainPercent = parseFloat( + prompt("Percentage of images for training?", "80"), + ); + if ( + Number.isNaN(trainPercent) || + trainPercent <= 0 || + trainPercent >= 100 + ) { + trainPercent = 80; + } + const numTrain = Math.floor(images.length * (trainPercent / 100)); + const zip = new JSZip(); + const imgTrain = zip.folder("images/train"); + const imgVal = zip.folder("images/val"); + const lblTrain = zip.folder("labels/train"); + const lblVal = zip.folder("labels/val"); + const labelMap = {}; + let nextId = 0; + for (let i = 0; i < images.length; i++) { + const image = images[i]; + const name = image.name; + const imgData = await image.arrayBuffer(); + const imgFolder = i < numTrain ? imgTrain : imgVal; + const lblFolder = i < numTrain ? 
lblTrain : lblVal; + imgFolder.file(name, imgData); + + const img = new Image(); + const url = URL.createObjectURL(image); + img.src = url; + await new Promise((resolve) => (img.onload = resolve)); + const iw = img.naturalWidth; + const ih = img.naturalHeight; + const annots = annotations[name] || []; + let txt = ""; + annots.forEach((obj) => { + if (!(obj.label in labelMap)) labelMap[obj.label] = nextId++; + const cx = (obj.left + obj.width / 2) / iw; + const cy = (obj.top + obj.height / 2) / ih; + const w = obj.width / iw; + const h = obj.height / ih; + txt += + labelMap[obj.label] + + " " + + cx.toFixed(6) + + " " + + cy.toFixed(6) + + " " + + w.toFixed(6) + + " " + + h.toFixed(6) + + "\n"; + }); + const labelFileName = name.replace(/\.[^/.]+$/, "") + ".txt"; + lblFolder.file(labelFileName, txt); + } + + const names = Object.keys(labelMap); + const yaml = `path: . + train: images/train + val: images/val + nc: ${names.length} + names: [${names.map((n) => `'${n}'`).join(", ")}] + `; + zip.file("dataset.yaml", yaml); + + const content = await zip.generateAsync({ type: "blob" }); + const a = document.createElement("a"); + a.href = URL.createObjectURL(content); + a.download = "yolo_dataset.zip"; + a.click(); + URL.revokeObjectURL(a.href); +}; + +synthBtn.onclick = async () => { + if (images.length === 0) { + alert("No images uploaded."); + return; + } + + const numImages = parseInt( + prompt("Number of synthetic images to generate?", "10"), + 10, + ); + const width = parseInt(prompt("Image width?", "640"), 10); + const height = parseInt(prompt("Image height?", "640"), 10); + + if ( + !numImages || + Number.isNaN(numImages) || + !width || + Number.isNaN(width) || + !height || + Number.isNaN(height) + ) { + alert("Invalid parameters."); + return; + } + + const crops = []; + const labelMap = {}; + let nextId = 0; + for (const file of images) { + const ann = annotations[file.name] || []; + for (const c of ann) { + crops.push({ file, ...c }); + if (!(c.label in labelMap)) labelMap[c.label] = nextId++; + } + } + + if (crops.length === 0) { + alert("No label found."); + return; + } + + async function loadImage(file) { + return await new Promise((resolve) => { + const img = new Image(); + img.src = URL.createObjectURL(file); + img.onload = () => { + URL.revokeObjectURL(img.src); + resolve(img); + }; + }); + } + + function overlaps(x, y, w, h, boxes) { + return boxes.some((b) => { + return !( + x + w <= b.x || + x >= b.x + b.w || + y + h <= b.y || + y >= b.y + b.h + ); + }); + } + + async function createCollage() { + const canvas = document.createElement("canvas"); + canvas.width = width; + canvas.height = height; + const ctx = canvas.getContext("2d"); + ctx.fillStyle = "white"; + ctx.fillRect(0, 0, width, height); + + const placed = []; + const numObj = Math.min(5, crops.length); + for (let i = 0; i < numObj; i++) { + const crop = crops[Math.floor(Math.random() * crops.length)]; + const img = await loadImage(crop.file); + const c = document.createElement("canvas"); + c.width = crop.width; + c.height = crop.height; + c.getContext("2d").drawImage( + img, + crop.left, + crop.top, + crop.width, + crop.height, + 0, + 0, + crop.width, + crop.height, + ); + + const scale = + (0.15 + 0.15 * Math.random()) * + (Math.min(width, height) / Math.max(crop.width, crop.height)); + const w = crop.width * scale; + const h = crop.height * scale; + + let x, y; + let tries = 0; + do { + x = Math.random() * (width - w); + y = Math.random() * (height - h); + tries += 1; + } while (tries < 50 && overlaps(x, y, w, h, 
placed)); + + if (tries === 50) continue; + + ctx.drawImage(c, 0, 0, crop.width, crop.height, x, y, w, h); + placed.push({ label: crop.label, x, y, w, h }); + } + + let txt = ""; + placed.forEach((p) => { + const cls = labelMap[p.label]; + const cx = (p.x + p.w / 2) / width; + const cy = (p.y + p.h / 2) / height; + const ww = p.w / width; + const hh = p.h / height; + txt += `${cls} ${cx.toFixed(6)} ${cy.toFixed(6)} ${ww.toFixed( + 6, + )} ${hh.toFixed(6)}\n`; + }); + + const blob = await new Promise((resolve) => + canvas.toBlob((b) => resolve(b), "image/jpeg"), + ); + return { blob, txt }; + } + + const zip = new JSZip(); + const imgTrain = zip.folder("images/train"); + const imgVal = zip.folder("images/val"); + const lblTrain = zip.folder("labels/train"); + const lblVal = zip.folder("labels/val"); + + const numTrain = Math.floor(numImages * 0.8); + + for (let i = 0; i < numImages; i++) { + const { blob, txt } = await createCollage(); + const imgName = `synt_${i}.jpg`; + const txtName = `synt_${i}.txt`; + if (i < numTrain) { + imgTrain.file(imgName, blob); + lblTrain.file(txtName, txt); + } else { + imgVal.file(imgName, blob); + lblVal.file(txtName, txt); + } + } + + const names = Object.keys(labelMap); + const yaml = `path: . + train: images/train + val: images/val + nc: ${names.length} + names: [${names.map((n) => `'${n}'`).join(", ")}] + `; + zip.file("dataset.yaml", yaml); + + const content = await zip.generateAsync({ type: "blob" }); + const a = document.createElement("a"); + a.href = URL.createObjectURL(content); + a.download = "synt_dataset.zip"; + a.click(); + URL.revokeObjectURL(a.href); +}; + +let navigatingAway = false; +document.addEventListener("click", (e) => { + const link = e.target.closest("a"); + if (link && link.href) { + navigatingAway = true; + } +}); + +window.addEventListener("pagehide", () => { + if (!navigatingAway) { + navigator.sendBeacon("/shutdown"); + } +}); diff --git a/src/dynsight/_internal/vision/label_tool/styles.css b/src/dynsight/_internal/vision/label_tool/styles.css new file mode 100644 index 00000000..fffb761f --- /dev/null +++ b/src/dynsight/_internal/vision/label_tool/styles.css @@ -0,0 +1,145 @@ +* { + user-select: none; +} +body { + margin: 0; + font-family: Arial, sans-serif; + display: flex; + height: 100vh; +} +.sidebar { + width: 250px; + background: #f4f4f4; + border-right: 1px solid #ccc; + padding: 20px; + box-sizing: border-box; + flex-shrink: 0; +} +.main { + flex: 1; + min-width: 0; + display: flex; + flex-direction: column; + background: #fff; +} +.top-bar { + padding: 10px 20px; + border-bottom: 1px solid #ccc; + background: #fafafa; + display: flex; + gap: 10px; + align-items: center; + flex-wrap: wrap; + flex-shrink: 0; +} +.image-container { + flex: 1; + position: relative; + background: #eee; + overflow: auto; + cursor: crosshair; +} + +.bounding-box { + position: absolute; + pointer-events: none; +} +.label-tag { + position: absolute; + top: -20px; + left: 0; + color: white; + padding: 2px 6px; + font-size: 12px; + font-weight: bold; + border-radius: 3px; + pointer-events: none; +} +.label-item { + background: #e0e0e0; + padding: 5px 10px; + margin-bottom: 5px; + border-radius: 4px; + cursor: pointer; +} +.label-item.active { + outline: 2px solid #2196f3; +} +.upload-btn, +.nav-btn { + margin-top: 5px; + padding: 5px 10px; + font-size: 14px; + cursor: pointer; +} +input[type="text"] { + width: 100%; + padding: 5px; + font-size: 14px; + box-sizing: border-box; +} + +input[type="range"] { + width: 300px; +} +.crosshair-line { + 
position: absolute; + pointer-events: none; + z-index: 10; +} + +#verticalLine { + width: 3px; + height: 300%; + top: 0; + background-image: repeating-linear-gradient( + to bottom, + rgba(255, 0, 0, 0.6) 0px, + rgba(255, 0, 0, 0.6) 5px, + transparent 5px, + transparent 10px + ); +} + +#horizontalLine { + height: 3px; + width: 300%; + left: 0; + background-image: repeating-linear-gradient( + to right, + rgba(255, 0, 0, 0.6) 0px, + rgba(255, 0, 0, 0.6) 5px, + transparent 5px, + transparent 10px + ); +} +#imageWrapper { + position: relative; + display: inline-block; + max-width: none; + max-height: none; + transform-origin: top left; +} + +#imageDisplay { + display: block; + max-width: none; + height: auto; + pointer-events: none; + user-drag: none; +} + +#overlay { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + pointer-events: none; +} + +.bounding-box { + position: absolute; + pointer-events: auto; + box-sizing: border-box; +} diff --git a/src/dynsight/_internal/vision/video_to_frame.py b/src/dynsight/_internal/vision/video_to_frame.py deleted file mode 100644 index e19409eb..00000000 --- a/src/dynsight/_internal/vision/video_to_frame.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import TYPE_CHECKING - -try: - import cv2 -except ImportError: - cv2 = None # type: ignore[assignment] - -if TYPE_CHECKING: - import pathlib - - -@dataclass -class Video: - """Load a video file and provide utilities. - - This class loads a video from a file path and provides methods - to retrieve video information and extract frames. - - * Author: Simone Martino - - .. caution:: - This part of the code is still under development and may - contain errors. - - Parameters: - video_path: Path to the video file. - - """ - - video_path: pathlib.Path - # OpenCV video capture object. - _capture: cv2.VideoCapture = field(init=False, repr=False) - - def __post_init__(self) -> None: - """Load the the video.""" - self._capture = cv2.VideoCapture(str(self.video_path)) - if not self._capture.isOpened(): - msg = f"Impossible to load the video: {self.video_path}" - raise ValueError(msg) - - def __del__(self) -> None: - """Close the the video.""" - if hasattr(self, "_capture") and self._capture.isOpened(): - self._capture.release() - - def count_frames(self) -> int: - """Counts the total number of frames in the video. - - Returns: - The number of frames in the video. - """ - return int(self._capture.get(cv2.CAP_PROP_FRAME_COUNT)) - - def resolution(self) -> tuple[int, int]: - """Retrieves video width and height in pixels. - - Returns: - A tuple `(width, height)` representing the frame dimensions. - """ - width = int(self._capture.get(cv2.CAP_PROP_FRAME_WIDTH)) - height = int(self._capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) - return (width, height) - - def extract_frames(self, working_dir: pathlib.Path) -> None: - """Extracts all frames from the video and saves them as PNG images. - - If it doesn't exist, creates a ``frames`` subdirectory inside - ``working_dir``, reads each frame from the video and - writes them to disk. - - Parameters: - working_dir: Directory in which to create a `frames` folder and - save extracted PNG images. 
-
-        """
-        frames_dir = working_dir / "frames"
-        frames_dir.mkdir(exist_ok=True)
-
-        self._capture.set(cv2.CAP_PROP_POS_FRAMES, 0)
-        total_frames = self.count_frames()
-
-        for frame_idx in range(total_frames):
-            _, frame = self._capture.read()
-            frame_filename = frames_dir / f"{frame_idx}.png"
-            cv2.imwrite(str(frame_filename), frame)
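The removed ``Video`` helper is not replaced one-to-one in this diff: the new ``VisionInstance`` below accepts video files directly as ``source``, so explicit frame extraction is normally no longer needed. For workflows that still want raw PNG frames, here is a minimal standalone sketch with OpenCV (the same library the deleted class used); the function name and paths are illustrative, not part of the diff:

from pathlib import Path

import cv2


def extract_frames(video_path: Path, working_dir: Path) -> None:
    """Save every frame of ``video_path`` as PNGs in ``working_dir/frames``."""
    frames_dir = working_dir / "frames"
    frames_dir.mkdir(parents=True, exist_ok=True)
    capture = cv2.VideoCapture(str(video_path))
    if not capture.isOpened():
        msg = f"Impossible to load the video: {video_path}"
        raise ValueError(msg)
    frame_idx = 0
    while True:
        ok, frame = capture.read()
        # Stop at the end of the stream instead of trusting the frame count.
        if not ok:
            break
        cv2.imwrite(str(frames_dir / f"{frame_idx}.png"), frame)
        frame_idx += 1
    capture.release()

Unlike the deleted implementation, this sketch checks the return value of ``capture.read()``, so a truncated video ends the loop cleanly instead of writing invalid frames.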
diff --git a/src/dynsight/_internal/vision/vision.py b/src/dynsight/_internal/vision/vision.py
new file mode 100644
index 00000000..22510407
--- /dev/null
+++ b/src/dynsight/_internal/vision/vision.py
@@ -0,0 +1,639 @@
+"""dynsight.vision module for particle detection from media files."""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Callable
+
+import numpy as np
+import torch
+import yaml
+from PIL import Image
+from ultralytics import YOLO
+
+if TYPE_CHECKING:
+    from ultralytics.engine.results import Results
+    from ultralytics.utils.metrics import DetMetrics
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+# Default hyperparameters dictionary.
+default_hyperparams = {
+    "lr0": 0.01,
+    "lrf": 0.01,
+    "momentum": 0.937,
+    "weight_decay": 0.0005,
+    "warmup_epochs": 3.0,
+    "warmup_momentum": 0.8,
+    "box": 7.5,
+    "cls": 0.5,
+    "dfl": 1.5,
+    "hsv_h": 0.015,
+    "hsv_s": 0.7,
+    "hsv_v": 0.4,
+    "degrees": 0.0,
+    "translate": 0.1,
+    "scale": 0.5,
+    "shear": 0.0,
+    "perspective": 0.0,
+    "flipud": 0.0,
+    "fliplr": 0.5,
+    "bgr": 0.0,
+    "mosaic": 1,
+    "mixup": 0.0,
+    "cutmix": 0.0,
+    "copy_paste": 0.0,
+}
+
+
+class VisionInstance:
+    def __init__(
+        self,
+        source: str | Path,
+        output_path: Path,
+        model: str | Path = "yolo12n.pt",
+        device: str | None = None,
+        workers: int = 8,
+    ) -> None:
+        """Class for performing computer vision tasks using YOLO models.
+
+        This class supports object detection, Convolutional Neural Network
+        (CNN) training and fine-tuning, as well as the creation and management
+        of training datasets.
+
+        .. caution::
+            This class is still under development and may not function as
+            intended.
+
+        Parameters:
+            source:
+                The source of the images or videos to be processed. For the
+                list of possible sources, see the following `sources table `_.
+                For the list of supported formats, see this `formats table `_.
+
+            output_path:
+                The path to save the output folder.
+
+            model:
+                The path to the YOLO model file. Defaults to "yolo12n.pt". See
+                `here `_ for more
+                information.
+
+            device:
+                Allows users to select between CPU, a specific GPU ID, or
+                "mps" for macOS users to perform the calculation
+                ("cuda:0" or "0" for GPUs, "cpu", or "mps" on macOS).
+
+            workers:
+                Number of worker threads for data loading. Influences the
+                speed of data preprocessing and feeding into the model,
+                especially useful in multi-GPU setups (only used for
+                training sessions).
+
+        """
+        self.output_path = Path(output_path)
+        self.training_data_yaml: Path | None = None
+
+        self.model = YOLO(model)
+        self.source = source
+        self.device = self._normalize_device_string(device)
+        self.workers = workers
+
+        self.prediction_results: list[Results] | None = None
+        self.training_results: DetMetrics | None = None
+
+        self._check_device()
+
+    def set_training_dataset(self, training_data_yaml: Path) -> None:
+        """Set the training dataset for model training.
+
+        The training dataset is set through a ``yaml`` file that should have
+        the following structure:
+
+        .. code-block:: yaml
+
+            path: path/to/dataset/folder
+            train: path/to/train/images
+            val: path/to/val/images
+
+            nc: number_of_classes
+            names: [class1, class2, ...]
+
+        With a dataset folder structure like this:
+
+        .. code-block:: none
+
+            dataset/
+            ├── images/
+            │   ├── train/
+            │   │   ├── 1.jpg
+            │   │   ├── 2.jpg
+            │   │   └── ...
+            │   └── val/
+            │       ├── 5.jpg
+            │       ├── 6.jpg
+            │       └── ...
+            └── labels/
+                ├── train/
+                │   ├── 1.txt
+                │   ├── 2.txt
+                │   └── ...
+                └── val/
+                    ├── 5.txt
+                    ├── 6.txt
+                    └── ...
+
+
+        Parameters:
+            training_data_yaml:
+                Path to the training data YAML file.
+        """
+        self.training_data_yaml = training_data_yaml
+
+    def predict(
+        self,
+        prediction_title: str,
+        augment: bool = False,
+        agnostic_nms: bool = False,
+        show_labels: bool = False,
+        class_filter: list[int] | None = None,
+        confidence: float = 0.25,
+        iou: float = 0.7,
+        imgsz: int | tuple[int, int] = 640,
+        max_det: int = 500,
+    ) -> None:
+        """Detect objects within the source.
+
+        Parameters:
+            prediction_title:
+                The name of the prediction session.
+
+            augment:
+                Enables test-time augmentation (TTA) for predictions,
+                potentially improving detection robustness at the cost of
+                inference speed.
+
+            agnostic_nms:
+                Enables class-agnostic Non-Maximum Suppression (NMS), which
+                merges overlapping boxes of different classes. Useful in
+                multi-class detection scenarios where class overlap is common.
+
+            show_labels:
+                Show label names in the saved, annotated version of the
+                source.
+
+            class_filter:
+                Filters predictions to a set of class IDs. Only detections
+                belonging to the specified classes will be returned.
+
+            confidence:
+                Sets the minimum confidence threshold for detections.
+                Objects detected with confidence below this threshold will
+                be disregarded.
+
+            iou:
+                Intersection over Union (IoU) threshold for Non-Maximum
+                Suppression. Lower values result in fewer detections by
+                eliminating overlapping boxes, useful for reducing
+                duplicates.
+
+            imgsz:
+                Defines the image size for inference. Can be a single integer
+                for square resizing or a tuple. Proper sizing can improve
+                detection accuracy and processing speed.
+
+            max_det:
+                The maximum number of detections for a single frame / image.
+
+        """
+        self.prediction_results = self.model.predict(
+            source=self.source,
+            save=True,
+            save_txt=False,
+            save_conf=True,
+            show_labels=show_labels,
+            name=prediction_title,
+            project=self.output_path,
+            device=self.device,
+            augment=augment,
+            agnostic_nms=agnostic_nms,
+            classes=class_filter,
+            conf=confidence,
+            iou=iou,
+            imgsz=imgsz,
+            max_det=max_det,
+        )
+
+    def create_dataset_from_predictions(
+        self,
+        dataset_name: str,
+        train_split: float = 0.8,
+        load_dataset: bool = True,
+    ) -> None:
+        """Create a YOLO training dataset from ``predict`` results.
+
+        Parameters:
+            dataset_name:
+                Name of the dataset that will be created.
+
+            train_split:
+                Fraction of images to be used as the training set; the
+                remaining fraction will be used for the validation set.
+
+            load_dataset:
+                If ``True``, directly load the created dataset for the next
+                training sessions.
+        """
+        if self.prediction_results is None:
+            msg = "No prediction results available."
+            raise ValueError(msg)
+
+        dataset_path = self.output_path / dataset_name
+        images_train = dataset_path / "images" / "train"
+        images_val = dataset_path / "images" / "val"
+        labels_train = dataset_path / "labels" / "train"
+        labels_val = dataset_path / "labels" / "val"
+
+        images_train.mkdir(parents=True, exist_ok=True)
+        images_val.mkdir(parents=True, exist_ok=True)
+        labels_train.mkdir(parents=True, exist_ok=True)
+        labels_val.mkdir(parents=True, exist_ok=True)
+
+        names = self.prediction_results[0].names
+
+        sorted_results = sorted(self.prediction_results, key=lambda r: r.path)
+
+        num_train = int(len(sorted_results) * train_split)
+
+        video_exts = {".mp4", ".avi", ".mov", ".mkv", ".webm"}
+        is_video = False
+        if (
+            isinstance(self.source, (str, Path))
+            and Path(self.source).suffix.lower() in video_exts
+        ):
+            is_video = True
+
+        for idx, result in enumerate(sorted_results):
+            src = Path(result.path)
+            subset = "train" if idx < num_train else "val"
+            if is_video:
+                frame_name = f"{src.stem}_{idx:06d}.jpg"
+                img_dst = dataset_path / "images" / subset / frame_name
+                lbl_dst = (
+                    dataset_path
+                    / "labels"
+                    / subset
+                    / (Path(frame_name).stem + ".txt")
+                )
+
+                img = Image.fromarray(result.orig_img[..., ::-1])
+                img.save(img_dst)
+            else:
+                img_dst = dataset_path / "images" / subset / src.name
+                lbl_dst = (
+                    dataset_path / "labels" / subset / (src.stem + ".txt")
+                )
+
+                img_dst.write_bytes(src.read_bytes())
+
+            boxes = result.boxes
+            if boxes is None:
+                lbl_dst.write_text("")
+                continue
+
+            xywhn = boxes.xywhn
+            classes = boxes.cls
+            with lbl_dst.open("w") as f:
+                for xywh, cls in zip(xywhn, classes):
+                    f.write(
+                        f"{int(cls)} {xywh[0]:.6f} {xywh[1]:.6f} "
+                        f"{xywh[2]:.6f} {xywh[3]:.6f}\n"
+                    )
+
+        dataset_yaml = dataset_path / "dataset.yaml"
+        yaml_content = {
+            "path": str(dataset_path.resolve()),
+            "train": "images/train",
+            "val": "images/val",
+            "nc": len(names),
+            "names": [names[i] for i in range(len(names))],
+        }
+        with dataset_yaml.open("w") as f:
+            yaml.safe_dump(yaml_content, f)
+
+        if load_dataset:
+            self.training_data_yaml = dataset_yaml
+
+    def tune_hyperparams(
+        self,
+        iterations: int = 15,
+        epochs: int = 50,
+        imgsz: int | tuple[int, int] = 640,
+        batch_size: int = 16,
+    ) -> dict[str, float]:
+        """Tune hyperparameters for the model.
+
+        Optimize the CNN hyperparameters by leveraging the Ultralytics YOLO
+        `genetic algorithm `_.
+        It returns a dictionary of the best hyperparameters, which can be
+        passed directly to the ``hyperparams`` parameter of the ``train``
+        method.
+
+        Parameters:
+            iterations:
+                The number of exploration iterations. The higher the number,
+                the more accurate the results will be, at a higher
+                computational cost.
+
+            epochs:
+                The number of epochs to perform for each iteration. Each epoch
+                represents a full pass over the entire dataset.
+
+            imgsz:
+                Defines the image size for inference. Can be a single integer
+                for square resizing or a tuple. Proper sizing can improve
+                detection accuracy and processing speed.
+
+            batch_size:
+                Three modes available: set as an integer (batch=16),
+                auto mode for 60% GPU memory utilization (batch=-1), or auto
+                mode with specified utilization fraction (batch=0.70).
+        """
+        if self.training_data_yaml is None:
+            msg = "Training dataset has not been set."
+            raise ValueError(msg)
+
+        self.model.tune(
+            data=self.training_data_yaml,
+            epochs=epochs,
+            iterations=iterations,
+            project=self.output_path / "tuning",
+            name="results",
+            device=self.device,
+            imgsz=imgsz,
+            batch=batch_size,
+        )
+        yaml_path = (
+            self.output_path
+            / "tuning"
+            / "results"
+            / "best_hyperparameters.yaml"
+        )
+        with yaml_path.open("r") as f:
+            return yaml.safe_load(f)
+
+    def train(
+        self,
+        title: str,
+        hyperparams: dict[str, float] | None = None,
+        epochs: int = 100,
+        batch_size: int = 16,
+        patience: int = 20,
+        imgsz: int | tuple[int, int] = 640,
+    ) -> None:
+        """Train a custom model using a training dataset.
+
+        The training dataset should be set with the ``set_training_dataset``
+        method before calling this function.
+
+        Parameters:
+            title:
+                The name of the resulting model.
+
+            hyperparams:
+                The dictionary that contains all the hyperparameters for the
+                model training. The following default ``dict`` is used if not
+                provided:
+
+                .. code-block:: python
+
+                    # Default hyperparameters dictionary.
+                    default_hyperparams = {
+                        "lr0": 0.01,
+                        "lrf": 0.01,
+                        "momentum": 0.937,
+                        "weight_decay": 0.0005,
+                        "warmup_epochs": 3.0,
+                        "warmup_momentum": 0.8,
+                        "box": 7.5,
+                        "cls": 0.5,
+                        "dfl": 1.5,
+                        "hsv_h": 0.015,
+                        "hsv_s": 0.7,
+                        "hsv_v": 0.4,
+                        "degrees": 0.0,
+                        "translate": 0.1,
+                        "scale": 0.5,
+                        "shear": 0.0,
+                        "perspective": 0.0,
+                        "flipud": 0.0,
+                        "fliplr": 0.5,
+                        "bgr": 0.0,
+                        "mosaic": 1,
+                        "mixup": 0.0,
+                        "cutmix": 0.0,
+                        "copy_paste": 0.0
+                    }
+
+                Manually customize this ``dict`` to change the training
+                performance or use the ``tune_hyperparams`` method to
+                automatically optimize hyperparameters.
+
+            epochs:
+                Total number of training epochs. Each epoch represents a full
+                pass over the entire dataset.
+
+            batch_size:
+                Three modes available: set as an integer (batch=16),
+                auto mode for 60% GPU memory utilization (batch=-1), or auto
+                mode with specified utilization fraction (batch=0.70).
+
+            patience:
+                Number of epochs to wait without improvement in validation
+                metrics before early stopping the training. Helps to prevent
+                overfitting.
+
+            imgsz:
+                Defines the image size for training. Can be a single integer
+                for square resizing or a tuple. Proper sizing can improve
+                detection accuracy and processing speed.
+
+        """
+        if self.training_data_yaml is None:
+            msg = "Training dataset has not been set."
+            raise ValueError(msg)
+
+        full_params = default_hyperparams.copy()
+        if hyperparams is not None:
+            unknown_keys = set(hyperparams.keys()) - set(full_params.keys())
+            if unknown_keys:
+                msg = f"Unknown hyperparameters: {unknown_keys}"
+                raise ValueError(msg)
+            for key in hyperparams:
+                full_params[key] = hyperparams[key]
+
+        self.training_results = self.model.train(
+            data=self.training_data_yaml,
+            epochs=epochs,
+            imgsz=imgsz,
+            batch=batch_size,
+            workers=self.workers,
+            name=title,
+            project=self.output_path,
+            patience=patience,
+            device=self.device,
+            **full_params,
+        )
+        self.model = YOLO(self.output_path / title / "weights" / "best.pt")
+
+    def export_prediction_to_xyz(
+        self, file_name: Path, class_filter: list[int] | None = None
+    ) -> Path:
+        """Export prediction results into a single ``.xyz`` file.
+
+        Each frame of the resulting ``.xyz`` corresponds to one of the
+        images/frames present in the source and used in the ``predict``
+        method.
+
+        Parameters:
+            file_name:
+                File name for the ``.xyz`` file.
+
+            class_filter:
+                Limit exported detections to the specified class IDs.
If + ``None`` all detected objects will be exported. + + Returns: + Path to the exported ``.xyz`` file. + """ + if self.prediction_results is None: + msg = "No prediction results available." + raise ValueError(msg) + + sorted_results = sorted(self.prediction_results, key=lambda r: r.path) + file_path = self.output_path / file_name + + with file_path.open("w") as f: + for result in sorted_results: + boxes = result.boxes + + coords: list[str] = [] + if boxes is not None: + xyxy_raw = boxes.xyxy + if isinstance(xyxy_raw, torch.Tensor): + xyxy = xyxy_raw.cpu().numpy() + else: + xyxy = np.asarray(xyxy_raw) + + cls_raw = boxes.cls + if isinstance(cls_raw, torch.Tensor): + classes = cls_raw.cpu().numpy().astype(int) + else: + classes = np.asarray(cls_raw).astype(int) + for (x1, y1, x2, y2), cls_id in zip(xyxy, classes): + if ( + class_filter is not None + and cls_id not in class_filter + ): + continue + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + coords.append(f"{cls_id} {cx:.6f} {cy:.6f} 0.0") + + f.write(f"{len(coords)}\n") + f.write("class x y z\n") + for line in coords: + f.write(f"{line}\n") + return file_path + + def _normalize_device_string(self, device: str | None) -> str: + """Normalize device string to match Ultralytics expectations.""" + if device is None: + return "0" if torch.cuda.is_available() else "cpu" + + device = str(device).lower() + + if device in {"cpu", "mps", "cuda"}: + return device + + # Allow "cuda:0" -> "0", "cuda:0,1" -> "0,1" + if device.startswith("cuda:"): + return device.replace("cuda:", "") + + # Allow "0", "0,1", etc. + if all(part.strip().isdigit() for part in device.split(",")): + return device + msg = f"Unsupported device string: '{device}'" + raise ValueError(msg) + + def _check_device(self) -> None: + """Verify and validate the selected device for compatibility.""" + self.device = self._normalize_device_string(self.device) + + def _device_error(msg: str) -> None: + raise RuntimeError(msg) + + try: + if self.device == "cpu": + self._check_cpu_device() + elif self.device == "mps": + self._check_mps_device(_device_error) + elif self.device == "cuda": + self._check_single_cuda_device(_device_error) + elif all( + part.strip().isdigit() for part in self.device.split(",") + ): + self._check_multi_cuda_devices(_device_error) + else: + _device_error(f"Unsupported device string: '{self.device}'") + except (ValueError, RuntimeError, IndexError, OSError) as e: + _device_error(str(e)) + + def _check_cpu_device(self) -> None: + logger.info("Using CPU.") + + def _check_mps_device(self, _device_error: Callable[[str], None]) -> None: + if not ( + hasattr(torch.backends, "mps") + and torch.backends.mps.is_available() + ): + _device_error("MPS device requested but not available.") + logger.info("Using Apple MPS backend.") + + def _check_single_cuda_device( + self, _device_error: Callable[[str], None] + ) -> None: + if not torch.cuda.is_available(): + _device_error("CUDA requested but not available.") + name = torch.cuda.get_device_name(0) + backend = "ROCm" if torch.version.hip else "CUDA" + mem_free, mem_total = torch.cuda.mem_get_info(0) + logger.info(f"Using GPU 0: {name} [{backend}]") + logger.info( + "Memory: %.1f MB free / %.1f MB total", + mem_free / 1024**2, + mem_total / 1024**2, + ) + _ = torch.tensor([0.0]).to("cuda:0") + + def _check_multi_cuda_devices( + self, _device_error: Callable[[str], None] + ) -> None: + gpus = [int(d) for d in self.device.split(",")] + for idx in gpus: + if idx >= torch.cuda.device_count(): + _device_error( + f"Requested GPU index {idx}, 
but only " + f"{torch.cuda.device_count()} available." + ) + name = torch.cuda.get_device_name(idx) + mem_free, mem_total = torch.cuda.mem_get_info(idx) + logger.info(f"Using GPU {idx}: {name}") + logger.info( + "Memory: %.1f MB free / %.1f MB total", + mem_free / 1024**2, + mem_total / 1024**2, + ) + _ = torch.tensor([0.0]).to(f"cuda:{gpus[0]}") diff --git a/src/dynsight/_internal/vision/vision_gui.py b/src/dynsight/_internal/vision/vision_gui.py deleted file mode 100644 index d15b4ec0..00000000 --- a/src/dynsight/_internal/vision/vision_gui.py +++ /dev/null @@ -1,220 +0,0 @@ -from __future__ import annotations - -import pathlib -import tkinter as tk -from dataclasses import dataclass -from pathlib import Path -from typing import Any - -from PIL import Image - - -@dataclass -class Box: - id: int - center_x: float - center_y: float - width: float - height: float - abs_coords: tuple[int, int, int, int] - - -class VisionGUI: - """GUI for interactively labeling images by drawing bounding boxes. - - * Author: Simone Martino - """ - - def __init__( - self, - master: tk.Tk, - image_path: pathlib.Path, - destination_folder: pathlib.Path = Path(__file__).parent, - ) -> None: - self.master = master - self.image_path = image_path - self.destination_folder = destination_folder - self.master.title("Dynsight: Label tool") - try: - self.image = tk.PhotoImage(file=image_path) - except Exception as e: - msg = f"Error loading image: {e}" - raise ValueError(msg) from e - # Setup the main grid - self.master.rowconfigure(index=0, weight=1) - # Image - self.master.columnconfigure(index=0, weight=1) - # Sidebar - self.master.columnconfigure(index=1, weight=1) - - # Image canvas - self.canvas = tk.Canvas( - master=self.master, - width=self.image.width(), - height=self.image.height(), - cursor="crosshair", - ) - self.canvas.grid(row=0, column=0, sticky="nsew") - self.canvas.create_image(0, 0, anchor=tk.NW, image=self.image) - # Rulers - self.h_line = self.canvas.create_line( - 0, # x0 - 0, # y0 - self.image.width(), # x1 - 0, # y1 - fill="blue", - dash=(2, 2), - width=3, - ) - self.v_line = self.canvas.create_line( - 0, # x0 - 0, # y0 - 0, # x1 - self.image.height(), # y1 - fill="blue", - dash=(2, 2), - width=3, - ) - # Sidebar - self.sidebar = tk.Frame( - master=self.master, - width=150, - padx=10, - pady=10, - ) - self.sidebar.grid(row=0, column=1, sticky="ns") - self.sidebar.grid_propagate(flag=False) - - # Buttons - self.submit_button = tk.Button( - self.sidebar, - text="Submit", - command=self._submit, - ) - self.submit_button.pack(pady=10, fill="x") - - self.undo_button = tk.Button( - self.sidebar, - text="Undo", - command=self._undo, - ) - self.undo_button.pack(pady=10, fill="x") - - self.close_button = tk.Button( - self.sidebar, - text="Close", - command=self._close, - ) - self.close_button.pack(pady=10, fill="x") - - # Labelling variables - self.start_x = 0 - self.start_y = 0 - self.current_box = 0 - self.boxes: list[Box] = [] - - # Mouse bindings - self.canvas.bind("", self._on_click_press) - self.canvas.bind("", self._on_click_release) - self.canvas.bind("", self._on_mouse_drag) - self.canvas.bind("", self._follow_mouse) - - # Mouse functions - def _on_click_press(self, event: tk.Event[Any]) -> None: - """Starts drawing the box on mouse press.""" - self.start_x, self.start_y = event.x, event.y - self.current_box = self.canvas.create_rectangle( - self.start_x, # x0 - self.start_y, # y0 - self.start_x, # x1 - self.start_y, # y1 - outline="red", - width=3, - ) - - def _on_click_release(self, event: 
tk.Event[Any]) -> None: - """Finalize the box on mouse release.""" - x2, y2 = event.x, event.y - x1, y1 = self.start_x, self.start_y - abs_coords = (min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)) - center_x = (x1 + x2) / (2 * self.image.width()) - center_y = (y1 + y2) / (2 * self.image.height()) - width_rel = abs(x2 - x1) / self.image.width() - height_rel = abs(y2 - y1) / self.image.height() - box_info = Box( - id=self.current_box, - center_x=center_x, - center_y=center_y, - width=width_rel, - height=height_rel, - abs_coords=abs_coords, - ) - self.boxes.append(box_info) - self.current_box = 0 - - def _on_mouse_drag(self, event: tk.Event[Any]) -> None: - """Update the box coordinates while dragging the mouse.""" - sel_x, sel_y = event.x, event.y - self.canvas.coords( - self.current_box, # ID - self.start_x, # x0 - self.start_y, # y0 - sel_x, # x1 - sel_y, # y1 - ) - # Sync rulers too - self.canvas.coords( - self.h_line, # ID - 0, # x0 - sel_y, # y0 - self.image.width(), # x1 - sel_y, # y1 - ) - self.canvas.coords( - self.v_line, # ID - sel_x, # x0 - 0, # y0 - sel_x, # x1 - self.image.height(), # y1 - ) - - def _follow_mouse(self, event: tk.Event[Any]) -> None: - """Sync guide lines position with mouse movement.""" - x, y = event.x, event.y - self.canvas.coords( - self.h_line, # ID - 0, # x0 - y, # y0 - self.image.width(), # x1 - y, # y1 - ) - self.canvas.coords( - self.v_line, # ID - x, # x0 - 0, # y0 - x, # x1 - self.image.height(), # y1 - ) - - # Button functions - def _submit(self) -> None: - """Save the cropped images.""" - cropped_img_folder = self.destination_folder / "training_items" - cropped_img_folder.mkdir(exist_ok=True) - pil_image = Image.open(self.image_path) - for i, box in enumerate(self.boxes): - abs_coords = box.abs_coords - cropped_image = pil_image.crop(abs_coords) - save_path = cropped_img_folder / f"{i}.png" - cropped_image.save(save_path) - self.master.destroy() - - def _undo(self) -> None: - """Remove the last drawn box.""" - if self.boxes: - last_box = self.boxes.pop() - self.canvas.delete(last_box.id) - - def _close(self) -> None: - """Close without saving.""" - self.master.destroy() diff --git a/src/dynsight/_internal/vision/vision_utilities.py b/src/dynsight/_internal/vision/vision_utilities.py deleted file mode 100644 index 213c7077..00000000 --- a/src/dynsight/_internal/vision/vision_utilities.py +++ /dev/null @@ -1,103 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import matplotlib.pyplot as plt -import numpy as np -from scipy.optimize import curve_fit -from scipy.stats import norm - -if TYPE_CHECKING: - import pathlib - - from numpy.typing import NDArray - - -def find_outliers( - distribution: NDArray[np.float64], - save_path: pathlib.Path, - fig_name: str, - thr: float = 1e-5, -) -> NDArray[np.float64]: - """Detects outliers in a distribution by fitting a normal distribution.""" - if distribution.size == 0: - msg = "Distribution is empty or contains only NaNs/Infs." 
- raise ValueError(msg) - if np.std(distribution) == 0: - return np.array([]) - - def _gaussian( - x: NDArray[np.float64], mu: float, sigma: float, amplitude: float - ) -> NDArray[np.float64]: - return amplitude * norm.pdf(x, mu, sigma) - - # Compute histogram and bin centers - hist, bin_edges = np.histogram(distribution, bins="auto", density=True) - bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 - - mask = np.isfinite(hist) & (hist > 0) - hist = hist[mask] - bin_centers = bin_centers[mask] - min_n_bins = 3 - if len(hist) < min_n_bins: - return np.array([]) - # Fit the Gaussian curve to the histogram data - popt, _ = curve_fit( - _gaussian, - bin_centers, - hist, - p0=[np.mean(distribution), np.std(distribution), np.max(hist)], - ) - mu, sigma, amplitude = popt - - # Generate fitted Gaussian curve for plotting - x = np.linspace(bin_edges[0], bin_edges[-1], 1000) - fitted_curve = _gaussian(x, mu, sigma, amplitude) - - # Calculate PDF threshold-based cutoffs - base_pdf = amplitude / (sigma * np.sqrt(2 * np.pi)) - x_threshold_min = mu - np.sqrt(-2 * sigma**2 * np.log(thr / base_pdf)) - x_threshold_max = mu + np.sqrt(-2 * sigma**2 * np.log(thr / base_pdf)) - - # Identify outliers using numpy boolean indexing - outliers: NDArray[np.float64] = distribution[ - (distribution < x_threshold_min) | (distribution > x_threshold_max) - ] - - # Plot histogram, fitted curve, and threshold lines - plt.hist( - distribution, - bins="auto", - density=True, - alpha=0.6, - color="g", - label="Histogram", - ) - plt.plot( - x, - fitted_curve, - "k-", - linewidth=2, - label=rf"Gaussian fit $\mu={mu:.2f},\ \sigma={sigma:.2f}$", - ) - plt.axvline( - x_threshold_min, - color="r", - linestyle="--", - label=f"Threshold Min = {x_threshold_min:.2f}", - ) - plt.axvline( - x_threshold_max, - color="b", - linestyle="--", - label=f"Threshold Max = {x_threshold_max:.2f}", - ) - plt.legend(loc="best") - plt.title("Histogram with Gaussian Fit and Thresholds") - plt.xlabel("Values") - plt.ylabel("Density") - plt.tight_layout() - plt.savefig(save_path / fig_name) - plt.close() - - return outliers diff --git a/src/dynsight/track.py b/src/dynsight/track.py new file mode 100644 index 00000000..b88355f5 --- /dev/null +++ b/src/dynsight/track.py @@ -0,0 +1,9 @@ +"""track package.""" + +from dynsight._internal.track.track import ( + track_xyz, +) + +__all__ = [ + "track_xyz", +] diff --git a/src/dynsight/utilities.py b/src/dynsight/utilities.py index b1315855..ccd1b422 100644 --- a/src/dynsight/utilities.py +++ b/src/dynsight/utilities.py @@ -4,10 +4,12 @@ find_extrema_points, load_or_compute_soap, normalize_array, + read_xyz, ) __all__ = [ "find_extrema_points", "load_or_compute_soap", "normalize_array", + "read_xyz", ] diff --git a/src/dynsight/vision.py b/src/dynsight/vision.py index 6c83a813..590004fc 100644 --- a/src/dynsight/vision.py +++ b/src/dynsight/vision.py @@ -1,13 +1,13 @@ """Vision package.""" -from dynsight._internal.vision.detect import ( - Detect, +from dynsight._internal.vision.label_tool import ( + label_tool, ) -from dynsight._internal.vision.video_to_frame import ( - Video, +from dynsight._internal.vision.vision import ( + VisionInstance, ) __all__ = [ - "Detect", - "Video", + "VisionInstance", + "label_tool", ] diff --git a/tests/systems/lj_id.xyz b/tests/systems/lj_id.xyz new file mode 100644 index 00000000..81e90a7e --- /dev/null +++ b/tests/systems/lj_id.xyz @@ -0,0 +1,210 @@ +5 +Step 0 +P0 -5.000000 2.000000 0.0 0 +P1 4.000000 -3.000000 0.0 1 +P2 1.500000 6.000000 0.0 2 +P3 -2.500000 -4.500000 0.0 3 +P4 
7.000000 3.500000 0.0 4 +5 +Step 1 +P0 -4.791128 2.036963 0.0 0 +P1 3.958151 -2.823608 0.0 1 +P2 1.447059 5.679399 0.0 2 +P3 -2.285658 -4.236846 0.0 3 +P4 6.816060 3.440246 0.0 4 +5 +Step 2 +P0 -4.511325 2.051341 0.0 0 +P1 3.697118 -2.574102 0.0 1 +P2 1.301944 5.384593 0.0 2 +P3 -2.092202 -3.880785 0.0 3 +P4 6.476730 3.246345 0.0 4 +5 +Step 3 +P0 -4.195438 2.112931 0.0 0 +P1 3.553053 -2.497301 0.0 1 +P2 1.389577 5.269173 0.0 2 +P3 -1.994445 -3.639558 0.0 3 +P4 6.115303 3.004959 0.0 4 +5 +Step 4 +P0 -3.980925 2.172511 0.0 0 +P1 3.427366 -2.235810 0.0 1 +P2 1.260627 4.994255 0.0 2 +P3 -1.853817 -3.349635 0.0 3 +P4 5.963871 2.870469 0.0 4 +5 +Step 5 +P0 -3.626436 2.108306 0.0 0 +P1 3.379681 -1.980356 0.0 1 +P2 1.241300 4.876043 0.0 2 +P3 -1.768785 -3.177578 0.0 3 +P4 5.615117 2.777250 0.0 4 +5 +Step 6 +P0 -3.359578 2.148831 0.0 0 +P1 3.160072 -1.882366 0.0 1 +P2 1.122169 4.787745 0.0 2 +P3 -1.517165 -3.049852 0.0 3 +P4 5.496275 2.670436 0.0 4 +5 +Step 7 +P0 -3.201726 2.130611 0.0 0 +P1 3.020153 -1.620133 0.0 1 +P2 1.156847 4.717972 0.0 2 +P3 -1.273898 -2.930248 0.0 3 +P4 5.255822 2.626861 0.0 4 +5 +Step 8 +P0 -2.953675 2.044091 0.0 0 +P1 2.961026 -1.434917 0.0 1 +P2 1.219418 4.486570 0.0 2 +P3 -1.261366 -2.650880 0.0 3 +P4 4.923909 2.479460 0.0 4 +5 +Step 9 +P0 -2.667821 1.878745 0.0 0 +P1 2.785395 -1.319951 0.0 1 +P2 1.141136 4.276260 0.0 2 +P3 -1.238114 -2.560183 0.0 3 +P4 4.626054 2.317652 0.0 4 +5 +Step 10 +P0 -2.576834 1.713102 0.0 0 +P1 2.749895 -1.163980 0.0 1 +P2 1.116430 4.014591 0.0 2 +P3 -1.210798 -2.358824 0.0 3 +P4 4.384192 2.300885 0.0 4 +5 +Step 11 +P0 -2.436860 1.617456 0.0 0 +P1 2.752405 -1.144044 0.0 1 +P2 1.005449 3.977450 0.0 2 +P3 -0.991517 -2.289499 0.0 3 +P4 4.117488 2.191018 0.0 4 +5 +Step 12 +P0 -2.270411 1.667002 0.0 0 +P1 2.676300 -1.114209 0.0 1 +P2 1.030007 3.793655 0.0 2 +P3 -0.926276 -2.209220 0.0 3 +P4 4.070431 2.030772 0.0 4 +5 +Step 13 +P0 -2.214457 1.650766 0.0 0 +P1 2.612264 -1.110300 0.0 1 +P2 1.053063 3.602497 0.0 2 +P3 -0.872838 -2.096902 0.0 3 +P4 3.818307 2.083224 0.0 4 +5 +Step 14 +P0 -2.096298 1.607129 0.0 0 +P1 2.416233 -1.016174 0.0 1 +P2 1.006852 3.578933 0.0 2 +P3 -0.844236 -1.891815 0.0 3 +P4 3.671041 2.127046 0.0 4 +5 +Step 15 +P0 -1.905863 1.464007 0.0 0 +P1 2.277212 -0.999618 0.0 1 +P2 1.034141 3.334519 0.0 2 +P3 -0.800417 -1.755370 0.0 3 +P4 3.600252 1.948550 0.0 4 +5 +Step 16 +P0 -1.741268 1.453377 0.0 0 +P1 2.092067 -0.990299 0.0 1 +P2 1.099642 3.149954 0.0 2 +P3 -0.785758 -1.540086 0.0 3 +P4 3.553555 1.873464 0.0 4 +5 +Step 17 +P0 -1.586963 1.370763 0.0 0 +P1 1.994711 -0.833266 0.0 1 +P2 1.116334 3.086626 0.0 2 +P3 -0.712209 -1.331583 0.0 3 +P4 3.299454 1.874486 0.0 4 +5 +Step 18 +P0 -1.498039 1.257161 0.0 0 +P1 1.925974 -0.777464 0.0 1 +P2 1.140884 3.098851 0.0 2 +P3 -0.709150 -1.166004 0.0 3 +P4 3.227851 1.939879 0.0 4 +5 +Step 19 +P0 -1.339758 1.245789 0.0 0 +P1 1.929810 -0.791418 0.0 1 +P2 1.153426 2.940205 0.0 2 +P3 -0.530508 -0.940736 0.0 3 +P4 3.222744 1.820487 0.0 4 +5 +Step 20 +P0 -1.280262 1.300132 0.0 0 +P1 1.838503 -0.794955 0.0 1 +P2 1.171411 2.763938 0.0 2 +P3 -0.449084 -0.757035 0.0 3 +P4 3.125288 1.750279 0.0 4 +5 +Step 21 +P0 -1.121547 1.382616 0.0 0 +P1 1.786073 -0.777513 0.0 1 +P2 1.271980 2.716131 0.0 2 +P3 -0.259758 -0.707211 0.0 3 +P4 2.933658 1.759766 0.0 4 +5 +Step 22 +P0 -1.105333 1.302314 0.0 0 +P1 1.689657 -0.701603 0.0 1 +P2 1.296014 2.708494 0.0 2 +P3 -0.223292 -0.522848 0.0 3 +P4 2.732958 1.772295 0.0 4 +5 +Step 23 +P0 -1.034569 1.318364 0.0 0 +P1 1.641425 -0.698426 0.0 1 +P2 1.382320 2.722614 0.0 2 +P3 
-0.136953 -0.448927 0.0 3 +P4 2.680647 1.807876 0.0 4 +5 +Step 24 +P0 -0.997685 1.245300 0.0 0 +P1 1.656599 -0.682829 0.0 1 +P2 1.415509 2.693962 0.0 2 +P3 -0.134423 -0.412289 0.0 3 +P4 2.708580 1.783989 0.0 4 +5 +Step 25 +P0 -0.798155 1.256512 0.0 0 +P1 1.680035 -0.563239 0.0 1 +P2 1.341091 2.629722 0.0 2 +P3 -0.115892 -0.449490 0.0 3 +P4 2.602751 1.785441 0.0 4 +5 +Step 26 +P0 -0.781771 1.142838 0.0 0 +P1 1.756680 -0.396595 0.0 1 +P2 1.276294 2.642729 0.0 2 +P3 -0.137163 -0.365647 0.0 3 +P4 2.607984 1.749603 0.0 4 +5 +Step 27 +P0 -0.738512 1.238783 0.0 0 +P1 1.777029 -0.266889 0.0 1 +P2 1.351501 2.612630 0.0 2 +P3 -0.130518 -0.231435 0.0 3 +P4 2.599496 1.639626 0.0 4 +5 +Step 28 +P0 -0.656927 1.286644 0.0 0 +P1 1.786953 -0.246108 0.0 1 +P2 1.240883 2.459362 0.0 2 +P3 0.025398 -0.287629 0.0 3 +P4 2.516386 1.577458 0.0 4 +5 +Step 29 +P0 -0.460374 1.358828 0.0 0 +P1 1.670440 -0.141020 0.0 1 +P2 1.168045 2.434449 0.0 2 +P3 -0.040058 -0.302320 0.0 3 +P4 2.553050 1.458213 0.0 4 diff --git a/tests/systems/lj_noid.xyz b/tests/systems/lj_noid.xyz new file mode 100644 index 00000000..a4ca14ba --- /dev/null +++ b/tests/systems/lj_noid.xyz @@ -0,0 +1,210 @@ +5 +Step 0 +P0 -5.000000 2.000000 0.0 +P1 4.000000 -3.000000 0.0 +P2 1.500000 6.000000 0.0 +P3 -2.500000 -4.500000 0.0 +P4 7.000000 3.500000 0.0 +5 +Step 1 +P0 -4.791128 2.036963 0.0 +P1 3.958151 -2.823608 0.0 +P2 1.447059 5.679399 0.0 +P3 -2.285658 -4.236846 0.0 +P4 6.816060 3.440246 0.0 +5 +Step 2 +P0 -4.511325 2.051341 0.0 +P1 3.697118 -2.574102 0.0 +P2 1.301944 5.384593 0.0 +P3 -2.092202 -3.880785 0.0 +P4 6.476730 3.246345 0.0 +5 +Step 3 +P0 -4.195438 2.112931 0.0 +P1 3.553053 -2.497301 0.0 +P2 1.389577 5.269173 0.0 +P3 -1.994445 -3.639558 0.0 +P4 6.115303 3.004959 0.0 +5 +Step 4 +P0 -3.980925 2.172511 0.0 +P1 3.427366 -2.235810 0.0 +P2 1.260627 4.994255 0.0 +P3 -1.853817 -3.349635 0.0 +P4 5.963871 2.870469 0.0 +5 +Step 5 +P0 -3.626436 2.108306 0.0 +P1 3.379681 -1.980356 0.0 +P2 1.241300 4.876043 0.0 +P3 -1.768785 -3.177578 0.0 +P4 5.615117 2.777250 0.0 +5 +Step 6 +P0 -3.359578 2.148831 0.0 +P1 3.160072 -1.882366 0.0 +P2 1.122169 4.787745 0.0 +P3 -1.517165 -3.049852 0.0 +P4 5.496275 2.670436 0.0 +5 +Step 7 +P0 -3.201726 2.130611 0.0 +P1 3.020153 -1.620133 0.0 +P2 1.156847 4.717972 0.0 +P3 -1.273898 -2.930248 0.0 +P4 5.255822 2.626861 0.0 +5 +Step 8 +P0 -2.953675 2.044091 0.0 +P1 2.961026 -1.434917 0.0 +P2 1.219418 4.486570 0.0 +P3 -1.261366 -2.650880 0.0 +P4 4.923909 2.479460 0.0 +5 +Step 9 +P0 -2.667821 1.878745 0.0 +P1 2.785395 -1.319951 0.0 +P2 1.141136 4.276260 0.0 +P3 -1.238114 -2.560183 0.0 +P4 4.626054 2.317652 0.0 +5 +Step 10 +P0 -2.576834 1.713102 0.0 +P1 2.749895 -1.163980 0.0 +P2 1.116430 4.014591 0.0 +P3 -1.210798 -2.358824 0.0 +P4 4.384192 2.300885 0.0 +5 +Step 11 +P0 -2.436860 1.617456 0.0 +P1 2.752405 -1.144044 0.0 +P2 1.005449 3.977450 0.0 +P3 -0.991517 -2.289499 0.0 +P4 4.117488 2.191018 0.0 +5 +Step 12 +P0 -2.270411 1.667002 0.0 +P1 2.676300 -1.114209 0.0 +P2 1.030007 3.793655 0.0 +P3 -0.926276 -2.209220 0.0 +P4 4.070431 2.030772 0.0 +5 +Step 13 +P0 -2.214457 1.650766 0.0 +P1 2.612264 -1.110300 0.0 +P2 1.053063 3.602497 0.0 +P3 -0.872838 -2.096902 0.0 +P4 3.818307 2.083224 0.0 +5 +Step 14 +P0 -2.096298 1.607129 0.0 +P1 2.416233 -1.016174 0.0 +P2 1.006852 3.578933 0.0 +P3 -0.844236 -1.891815 0.0 +P4 3.671041 2.127046 0.0 +5 +Step 15 +P0 -1.905863 1.464007 0.0 +P1 2.277212 -0.999618 0.0 +P2 1.034141 3.334519 0.0 +P3 -0.800417 -1.755370 0.0 +P4 3.600252 1.948550 0.0 +5 +Step 16 +P0 -1.741268 1.453377 0.0 +P1 2.092067 -0.990299 
0.0 +P2 1.099642 3.149954 0.0 +P3 -0.785758 -1.540086 0.0 +P4 3.553555 1.873464 0.0 +5 +Step 17 +P0 -1.586963 1.370763 0.0 +P1 1.994711 -0.833266 0.0 +P2 1.116334 3.086626 0.0 +P3 -0.712209 -1.331583 0.0 +P4 3.299454 1.874486 0.0 +5 +Step 18 +P0 -1.498039 1.257161 0.0 +P1 1.925974 -0.777464 0.0 +P2 1.140884 3.098851 0.0 +P3 -0.709150 -1.166004 0.0 +P4 3.227851 1.939879 0.0 +5 +Step 19 +P0 -1.339758 1.245789 0.0 +P1 1.929810 -0.791418 0.0 +P2 1.153426 2.940205 0.0 +P3 -0.530508 -0.940736 0.0 +P4 3.222744 1.820487 0.0 +5 +Step 20 +P0 -1.280262 1.300132 0.0 +P1 1.838503 -0.794955 0.0 +P2 1.171411 2.763938 0.0 +P3 -0.449084 -0.757035 0.0 +P4 3.125288 1.750279 0.0 +5 +Step 21 +P0 -1.121547 1.382616 0.0 +P1 1.786073 -0.777513 0.0 +P2 1.271980 2.716131 0.0 +P3 -0.259758 -0.707211 0.0 +P4 2.933658 1.759766 0.0 +5 +Step 22 +P0 -1.105333 1.302314 0.0 +P1 1.689657 -0.701603 0.0 +P2 1.296014 2.708494 0.0 +P3 -0.223292 -0.522848 0.0 +P4 2.732958 1.772295 0.0 +5 +Step 23 +P0 -1.034569 1.318364 0.0 +P1 1.641425 -0.698426 0.0 +P2 1.382320 2.722614 0.0 +P3 -0.136953 -0.448927 0.0 +P4 2.680647 1.807876 0.0 +5 +Step 24 +P0 -0.997685 1.245300 0.0 +P1 1.656599 -0.682829 0.0 +P2 1.415509 2.693962 0.0 +P3 -0.134423 -0.412289 0.0 +P4 2.708580 1.783989 0.0 +5 +Step 25 +P0 -0.798155 1.256512 0.0 +P1 1.680035 -0.563239 0.0 +P2 1.341091 2.629722 0.0 +P3 -0.115892 -0.449490 0.0 +P4 2.602751 1.785441 0.0 +5 +Step 26 +P0 -0.781771 1.142838 0.0 +P1 1.756680 -0.396595 0.0 +P2 1.276294 2.642729 0.0 +P3 -0.137163 -0.365647 0.0 +P4 2.607984 1.749603 0.0 +5 +Step 27 +P0 -0.738512 1.238783 0.0 +P1 1.777029 -0.266889 0.0 +P2 1.351501 2.612630 0.0 +P3 -0.130518 -0.231435 0.0 +P4 2.599496 1.639626 0.0 +5 +Step 28 +P0 -0.656927 1.286644 0.0 +P1 1.786953 -0.246108 0.0 +P2 1.240883 2.459362 0.0 +P3 0.025398 -0.287629 0.0 +P4 2.516386 1.577458 0.0 +5 +Step 29 +P0 -0.460374 1.358828 0.0 +P1 1.670440 -0.141020 0.0 +P2 1.168045 2.434449 0.0 +P3 -0.040058 -0.302320 0.0 +P4 2.553050 1.458213 0.0 diff --git a/tests/track/__init__.py b/tests/track/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/track/test_track.py b/tests/track/test_track.py new file mode 100644 index 00000000..fcb576d2 --- /dev/null +++ b/tests/track/test_track.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np + +from dynsight.track import track_xyz +from dynsight.utilities import read_xyz + + +def test_track_xyz(tmp_path: Path) -> None: + original_dir = Path(__file__).resolve().parent + filename = original_dir / "../systems/lj_noid.xyz" + file_with_id = original_dir / "../systems/lj_id.xyz" + output = tmp_path / "trajectory.xyz" + track_xyz(input_xyz=filename, output_xyz=output, search_range=10) + n_atoms = 5 + for _ in range(n_atoms): + arr1 = read_xyz( + input_xyz=output, cols_order=["name", "x", "y", "z", "ID"] + ).to_numpy() + arr2 = read_xyz( + input_xyz=file_with_id, cols_order=["name", "x", "y", "z", "ID"] + ).to_numpy() + assert arr1.shape == arr2.shape + assert np.array_equal(arr1, arr2) diff --git a/tests/vision/__init__.py b/tests/vision/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/vision/test_vision.py b/tests/vision/test_vision.py new file mode 100644 index 00000000..32398fb2 --- /dev/null +++ b/tests/vision/test_vision.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import yaml +from PIL import Image + +if TYPE_CHECKING: + from pathlib import Path + +from 
dynsight._internal.vision.vision import VisionInstance + +DEFAULT_MODEL = "yolo12n.pt" + + +def create_dummy_yolo_dataset( + root_path: Path, + num_train: int = 5, + num_val: int = 2, + image_size: tuple[int, int] = (100, 100), + num_classes: int = 1, + class_names: list[str] | None = None, + rng: np.random.Generator | None = None, +) -> None: + if rng is None: + rng = np.random.default_rng() + if class_names is None: + class_names = [f"class_{i}" for i in range(num_classes)] + + def generate_split(split_name: str, num_images: int) -> None: + images_dir = root_path / "dataset" / "images" / split_name + labels_dir = root_path / "dataset" / "labels" / split_name + images_dir.mkdir(parents=True, exist_ok=True) + labels_dir.mkdir(parents=True, exist_ok=True) + + for i in range(num_images): + array = rng.integers( + 0, 256, size=(image_size[1], image_size[0], 3), dtype=np.uint8 + ) + img = Image.fromarray(array) + img_path = images_dir / f"img_{i}.jpg" + img.save(img_path) + + x_center = np.round(rng.uniform(0.3, 0.7), 6) + y_center = np.round(rng.uniform(0.3, 0.7), 6) + width = np.round(rng.uniform(0.1, 0.3), 6) + height = np.round(rng.uniform(0.1, 0.3), 6) + class_id = int(rng.integers(0, num_classes)) + + label_path = labels_dir / f"img_{i}.txt" + label_path.write_text( + f"{class_id} {x_center} {y_center} {width} {height}\n" + ) + + generate_split("train", num_train) + generate_split("val", num_val) + + data_yaml = { + "path": str((root_path / "dataset").resolve()), + "train": str((root_path / "dataset" / "images" / "train").resolve()), + "val": str((root_path / "dataset" / "images" / "val").resolve()), + "nc": num_classes, + "names": class_names, + } + yaml_path = root_path / "data.yaml" + yaml_path.write_text(yaml.dump(data_yaml)) + + +def test_vision_instance_creation(tmp_path: Path) -> None: + source_path = tmp_path / "source.jpg" + img = Image.new("RGB", (100, 100)) + img.save(source_path) + model_path = tmp_path / DEFAULT_MODEL + out_path = tmp_path / "output" + + instance = VisionInstance( + source=source_path, + output_path=out_path, + model=model_path, + device="cpu", + workers=1, + ) + assert model_path.exists() + assert instance.training_data_yaml is None + assert instance.training_results is None + assert instance.prediction_results is None + assert instance.device == "cpu" + + +def test_vision_training(tmp_path: Path) -> None: + source_path = tmp_path / "source.jpg" + img = Image.new("RGB", (100, 100)) + img.save(source_path) + model_path = tmp_path / DEFAULT_MODEL + out_path = tmp_path / "output" + + instance = VisionInstance( + source=source_path, + output_path=out_path, + model=model_path, + device="cpu", + workers=1, + ) + old_model = instance.model + create_dummy_yolo_dataset(tmp_path) + instance.set_training_dataset(tmp_path / "data.yaml") + assert (tmp_path / "data.yaml").exists() + instance.train( + title="test_train", + epochs=1, + batch_size=-1, + imgsz=100, + ) + new_model = instance.model + new_model_path = out_path / "test_train" / "weights" / "best.pt" + assert instance.training_results is not None + + assert str(instance.training_results.names[0]) == "class_0" + assert str(instance.training_results.task) == "detect" + + assert new_model_path.exists() + assert old_model != new_model + + +def test_vision_predict(tmp_path: Path) -> None: + out_path = tmp_path / "output" + + source_path = tmp_path / "imgs" + source_path.mkdir(parents=True, exist_ok=True) + for i in range(10): + img = Image.new("RGB", (100, 100)) + img.save(source_path / f"img_{i}.jpg") + model_path = 
tmp_path / DEFAULT_MODEL + + instance = VisionInstance( + source=source_path, + output_path=out_path, + model=model_path, + device="cpu", + workers=1, + ) + instance.predict(prediction_title="test_predict") + assert instance.prediction_results is not None + for i in range(10): + assert (source_path / f"img_{i}.jpg").exists() + instance.create_dataset_from_predictions("test_dataset_from_pred") + + dataset_img_t = out_path / "test_dataset_from_pred" / "images" / "train" + dataset_img_v = out_path / "test_dataset_from_pred" / "images" / "val" + dataset_lab_t = out_path / "test_dataset_from_pred" / "labels" / "train" + dataset_lab_v = out_path / "test_dataset_from_pred" / "labels" / "val" + + files_img_t = list(dataset_img_t.glob("*.jpg")) + files_img_v = list(dataset_img_v.glob("*.jpg")) + files_lab_t = list(dataset_lab_t.glob("*.txt")) + files_lab_v = list(dataset_lab_v.glob("*.txt")) + + expected_train_set_len = 8 + expected_val_set_len = 2 + assert len(files_img_t) == expected_train_set_len + assert len(files_img_v) == expected_val_set_len + assert len(files_lab_t) == expected_train_set_len + assert len(files_lab_v) == expected_val_set_len + + +def test_vision_tuning(tmp_path: Path) -> None: + source_path = tmp_path / "source.jpg" + img = Image.new("RGB", (100, 100)) + img.save(source_path) + model_path = tmp_path / DEFAULT_MODEL + out_path = tmp_path / "output" + + instance = VisionInstance( + source=source_path, + output_path=out_path, + model=model_path, + device="cpu", + workers=1, + ) + create_dummy_yolo_dataset(tmp_path) + instance.set_training_dataset(tmp_path / "data.yaml") + hyp = instance.tune_hyperparams( + iterations=1, + epochs=1, + imgsz=100, + batch_size=-1, + ) + assert ( + out_path / "tuning" / "results" / "best_hyperparameters.yaml" + ).exists() + assert isinstance(hyp, dict)
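Taken together, the new ``vision`` and ``track`` modules form a detection-to-trajectory pipeline: detect particles frame by frame, export the box centers as an ``.xyz`` file, then link them across frames. Below is a minimal end-to-end sketch built only from the signatures introduced in this diff; the paths, the prediction title, the confidence value, and the ``search_range`` value are illustrative placeholders:

from pathlib import Path

from dynsight.track import track_xyz
from dynsight.vision import VisionInstance

# Illustrative paths; any YOLO-supported image or video source works.
frames_dir = Path("experiment/frames")
output_dir = Path("experiment/output")

# Detect particles in every frame with the default pretrained model.
instance = VisionInstance(
    source=frames_dir,
    output_path=output_dir,
    model="yolo12n.pt",
    device="cpu",
)
instance.predict(prediction_title="run_1", confidence=0.3)

# Write one .xyz frame per image: "class x y z" lines holding box centers.
xyz_file = instance.export_prediction_to_xyz(Path("detections.xyz"))

# Link detections across frames into per-particle trajectories with IDs.
track_xyz(
    input_xyz=xyz_file,
    output_xyz=output_dir / "trajectory.xyz",
    search_range=10,
)

Note that ``export_prediction_to_xyz`` raises a ``ValueError`` when ``predict`` has not been run first, and it writes the file inside ``output_path``; ``search_range=10`` mirrors the value exercised by ``test_track_xyz`` above.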