Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ jobs:
run: |
uv sync
uv run ruff check
uv run pytest tests/
uv run pytest -vv tests/
36 changes: 36 additions & 0 deletions tests/test_apnea_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest

np = pytest.importorskip("numpy")
pytest.importorskip("scipy")
pytest.importorskip("helia_edge")

from sleepkit.defines import SleepApnea # noqa: E402
from sleepkit.tasks.apnea.metrics import ( # noqa: E402
compute_apnea_efficiency,
compute_apnea_hypopnea_index,
compute_sleep_apnea_durations,
)


def test_compute_sleep_apnea_durations_counts_segments():
mask = np.array([0, 0, 1, 1, 1, 0, 2, 2], dtype=int)
durations = compute_sleep_apnea_durations(mask)
assert durations == {0: 3, 1: 3, 2: 2}


def test_compute_apnea_efficiency_uses_mapped_classes_only():
durations = {0: 90, 1: 10, 2: 50}
class_map = {
SleepApnea.none: 0,
SleepApnea.hypopnea: 1,
}
efficiency = compute_apnea_efficiency(durations, class_map)
assert efficiency == 0.9


def test_compute_apnea_hypopnea_index_counts_events_over_min_duration():
mask = np.zeros(3600, dtype=int)
mask[100:105] = 1
mask[200:203] = 1
ahi = compute_apnea_hypopnea_index(mask, min_duration=4, sample_rate=1)
assert ahi == 1.0
77 changes: 77 additions & 0 deletions tests/test_serial_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from dataclasses import dataclass

import pytest

pytest.importorskip("helia_edge")

import sleepkit.backends.utils as utils # noqa: E402


@dataclass
class DummyPort:
device: str
vid: int
pid: int
serial_number: str | None = None
manufacturer: str | None = None
product: str | None = None


class DummySerialTransport:
def __init__(self, device, baudrate=None):
self.device = device
self.baudrate = baudrate


def test_find_serial_device_matches_fields(monkeypatch):
ports = [
DummyPort(
device="/dev/ttyUSB0",
vid=1234,
pid=5678,
serial_number="ABC123",
manufacturer="Acme",
product="Widget",
),
DummyPort(
device="/dev/ttyUSB1",
vid=1111,
pid=2222,
serial_number="XYZ789",
manufacturer="Other",
product="Gadget",
),
]

monkeypatch.setattr(utils, "list_ports", lambda: ports)

port = utils._find_serial_device(vid_pid="1234:5678")
assert port is ports[0]

port = utils._find_serial_device(manufacturer="acme", product="widget")
assert port is ports[0]

port = utils._find_serial_device(serial_number="XYZ789")
assert port is ports[1]


def test_get_serial_transport_uses_first_matching_port(monkeypatch):
dummy_port = DummyPort(device="/dev/ttyUSB9", vid=1, pid=2)
monkeypatch.setattr(utils, "_find_serial_device", lambda **kwargs: dummy_port)
monkeypatch.setattr(utils, "SerialTransport", DummySerialTransport)

transport = utils.get_serial_transport(vid_pid="1:2", baudrate=115200)
assert transport.device == "/dev/ttyUSB9"
assert transport.baudrate == 115200


def test_get_serial_transport_times_out(monkeypatch):
monkeypatch.setattr(utils, "_find_serial_device", lambda **kwargs: None)
monkeypatch.setattr(utils, "time", utils.time)

times = iter([0.0, 0.1, 0.2, 10.1])
monkeypatch.setattr(utils.time, "time", lambda: next(times))
monkeypatch.setattr(utils.time, "sleep", lambda _: None)

with pytest.raises(TimeoutError, match="Unable to locate serial port"):
utils.get_serial_transport(timeout=0.3)
24 changes: 24 additions & 0 deletions tests/test_stage_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest

np = pytest.importorskip("numpy")
pytest.importorskip("helia_edge")

from sleepkit.defines import SleepStage # noqa: E402
from sleepkit.tasks.stage.metrics import ( # noqa: E402
compute_sleep_efficiency,
compute_sleep_stage_durations,
)


def test_compute_sleep_stage_durations_and_efficiency():
mask = np.array([0, 0, 1, 1, 2, 2, 0], dtype=int)
durations = compute_sleep_stage_durations(mask)
assert durations == {0: 3, 1: 2, 2: 2}

class_map = {
SleepStage.wake: 0,
SleepStage.stage1: 1,
SleepStage.stage2: 2,
}
efficiency = compute_sleep_efficiency(durations, class_map)
assert efficiency == 4 / 7
30 changes: 30 additions & 0 deletions tests/test_taskparams_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from pathlib import Path

import pytest

pytest.importorskip("helia_edge")

from sleepkit.defines import TaskParams # noqa: E402


def test_taskparams_resolves_relative_paths(tmp_path):
params = TaskParams(
job_dir=tmp_path,
model_file=Path("model.keras"),
weights_file=Path("weights.ckpt"),
val_file=Path("val.tfrecord"),
test_file=Path("test.tfrecord"),
tflm_file=Path("model_buffer.h"),
)

assert params.model_file == tmp_path / "model.keras"
assert params.weights_file == tmp_path / "weights.ckpt"
assert params.val_file == tmp_path / "val.tfrecord"
assert params.test_file == tmp_path / "test.tfrecord"
assert params.tflm_file == tmp_path / "model_buffer.h"


def test_taskparams_keeps_absolute_paths(tmp_path):
absolute_path = tmp_path / "abs.keras"
params = TaskParams(job_dir=tmp_path, model_file=absolute_path)
assert params.model_file == absolute_path
6 changes: 3 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.