From 82a99931f2a38222d547bbc8a0e7e2f4ccc85bac Mon Sep 17 00:00:00 2001 From: Adam Page Date: Sun, 18 Jan 2026 21:54:34 +0000 Subject: [PATCH] test: add unit tests for metrics and utils --- .github/workflows/ci.yaml | 2 +- tests/test_apnea_metrics.py | 36 ++++++++++++++++ tests/test_serial_utils.py | 77 ++++++++++++++++++++++++++++++++++ tests/test_stage_metrics.py | 24 +++++++++++ tests/test_taskparams_paths.py | 30 +++++++++++++ uv.lock | 6 +-- 6 files changed, 171 insertions(+), 4 deletions(-) create mode 100644 tests/test_apnea_metrics.py create mode 100644 tests/test_serial_utils.py create mode 100644 tests/test_stage_metrics.py create mode 100644 tests/test_taskparams_paths.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8282d8b..142391d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -29,4 +29,4 @@ jobs: run: | uv sync uv run ruff check - uv run pytest tests/ + uv run pytest -vv tests/ diff --git a/tests/test_apnea_metrics.py b/tests/test_apnea_metrics.py new file mode 100644 index 0000000..d2fba7d --- /dev/null +++ b/tests/test_apnea_metrics.py @@ -0,0 +1,36 @@ +import pytest + +np = pytest.importorskip("numpy") +pytest.importorskip("scipy") +pytest.importorskip("helia_edge") + +from sleepkit.defines import SleepApnea # noqa: E402 +from sleepkit.tasks.apnea.metrics import ( # noqa: E402 + compute_apnea_efficiency, + compute_apnea_hypopnea_index, + compute_sleep_apnea_durations, +) + + +def test_compute_sleep_apnea_durations_counts_segments(): + mask = np.array([0, 0, 1, 1, 1, 0, 2, 2], dtype=int) + durations = compute_sleep_apnea_durations(mask) + assert durations == {0: 3, 1: 3, 2: 2} + + +def test_compute_apnea_efficiency_uses_mapped_classes_only(): + durations = {0: 90, 1: 10, 2: 50} + class_map = { + SleepApnea.none: 0, + SleepApnea.hypopnea: 1, + } + efficiency = compute_apnea_efficiency(durations, class_map) + assert efficiency == 0.9 + + +def test_compute_apnea_hypopnea_index_counts_events_over_min_duration(): + mask = np.zeros(3600, dtype=int) + mask[100:105] = 1 + mask[200:203] = 1 + ahi = compute_apnea_hypopnea_index(mask, min_duration=4, sample_rate=1) + assert ahi == 1.0 diff --git a/tests/test_serial_utils.py b/tests/test_serial_utils.py new file mode 100644 index 0000000..b16a77a --- /dev/null +++ b/tests/test_serial_utils.py @@ -0,0 +1,77 @@ +from dataclasses import dataclass + +import pytest + +pytest.importorskip("helia_edge") + +import sleepkit.backends.utils as utils # noqa: E402 + + +@dataclass +class DummyPort: + device: str + vid: int + pid: int + serial_number: str | None = None + manufacturer: str | None = None + product: str | None = None + + +class DummySerialTransport: + def __init__(self, device, baudrate=None): + self.device = device + self.baudrate = baudrate + + +def test_find_serial_device_matches_fields(monkeypatch): + ports = [ + DummyPort( + device="/dev/ttyUSB0", + vid=1234, + pid=5678, + serial_number="ABC123", + manufacturer="Acme", + product="Widget", + ), + DummyPort( + device="/dev/ttyUSB1", + vid=1111, + pid=2222, + serial_number="XYZ789", + manufacturer="Other", + product="Gadget", + ), + ] + + monkeypatch.setattr(utils, "list_ports", lambda: ports) + + port = utils._find_serial_device(vid_pid="1234:5678") + assert port is ports[0] + + port = utils._find_serial_device(manufacturer="acme", product="widget") + assert port is ports[0] + + port = utils._find_serial_device(serial_number="XYZ789") + assert port is ports[1] + + +def test_get_serial_transport_uses_first_matching_port(monkeypatch): + dummy_port = DummyPort(device="/dev/ttyUSB9", vid=1, pid=2) + monkeypatch.setattr(utils, "_find_serial_device", lambda **kwargs: dummy_port) + monkeypatch.setattr(utils, "SerialTransport", DummySerialTransport) + + transport = utils.get_serial_transport(vid_pid="1:2", baudrate=115200) + assert transport.device == "/dev/ttyUSB9" + assert transport.baudrate == 115200 + + +def test_get_serial_transport_times_out(monkeypatch): + monkeypatch.setattr(utils, "_find_serial_device", lambda **kwargs: None) + monkeypatch.setattr(utils, "time", utils.time) + + times = iter([0.0, 0.1, 0.2, 10.1]) + monkeypatch.setattr(utils.time, "time", lambda: next(times)) + monkeypatch.setattr(utils.time, "sleep", lambda _: None) + + with pytest.raises(TimeoutError, match="Unable to locate serial port"): + utils.get_serial_transport(timeout=0.3) diff --git a/tests/test_stage_metrics.py b/tests/test_stage_metrics.py new file mode 100644 index 0000000..e0784c0 --- /dev/null +++ b/tests/test_stage_metrics.py @@ -0,0 +1,24 @@ +import pytest + +np = pytest.importorskip("numpy") +pytest.importorskip("helia_edge") + +from sleepkit.defines import SleepStage # noqa: E402 +from sleepkit.tasks.stage.metrics import ( # noqa: E402 + compute_sleep_efficiency, + compute_sleep_stage_durations, +) + + +def test_compute_sleep_stage_durations_and_efficiency(): + mask = np.array([0, 0, 1, 1, 2, 2, 0], dtype=int) + durations = compute_sleep_stage_durations(mask) + assert durations == {0: 3, 1: 2, 2: 2} + + class_map = { + SleepStage.wake: 0, + SleepStage.stage1: 1, + SleepStage.stage2: 2, + } + efficiency = compute_sleep_efficiency(durations, class_map) + assert efficiency == 4 / 7 diff --git a/tests/test_taskparams_paths.py b/tests/test_taskparams_paths.py new file mode 100644 index 0000000..df404fd --- /dev/null +++ b/tests/test_taskparams_paths.py @@ -0,0 +1,30 @@ +from pathlib import Path + +import pytest + +pytest.importorskip("helia_edge") + +from sleepkit.defines import TaskParams # noqa: E402 + + +def test_taskparams_resolves_relative_paths(tmp_path): + params = TaskParams( + job_dir=tmp_path, + model_file=Path("model.keras"), + weights_file=Path("weights.ckpt"), + val_file=Path("val.tfrecord"), + test_file=Path("test.tfrecord"), + tflm_file=Path("model_buffer.h"), + ) + + assert params.model_file == tmp_path / "model.keras" + assert params.weights_file == tmp_path / "weights.ckpt" + assert params.val_file == tmp_path / "val.tfrecord" + assert params.test_file == tmp_path / "test.tfrecord" + assert params.tflm_file == tmp_path / "model_buffer.h" + + +def test_taskparams_keeps_absolute_paths(tmp_path): + absolute_path = tmp_path / "abs.keras" + params = TaskParams(job_dir=tmp_path, model_file=absolute_path) + assert params.model_file == absolute_path diff --git a/uv.lock b/uv.lock index ff6a1d6..dde7ca7 100644 --- a/uv.lock +++ b/uv.lock @@ -3299,7 +3299,7 @@ wheels = [ [[package]] name = "sleepkit" -version = "0.9.0" +version = "0.10.0" source = { editable = "." } dependencies = [ { name = "argdantic", extra = ["all"] }, @@ -3582,8 +3582,8 @@ name = "tensorflow-metal" version = "1.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "six" }, - { name = "wheel" }, + { name = "six", marker = "sys_platform != 'linux'" }, + { name = "wheel", marker = "sys_platform != 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/dc/bf/988b619322d5617a928e7f31cbb1ed8dd7f375f69dfa73dab26409a00382/tensorflow_metal-1.2.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4bece0ecb154b713b9f5ad4aec676a366d312822161e3cf0e1dea737c64cec04", size = 1357400, upload-time = "2025-01-31T00:52:57.924Z" },