Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 120 additions & 35 deletions src/caliscope/core/process_synchronized_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import logging
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
Expand Down Expand Up @@ -41,6 +41,7 @@ def process_synchronized_recording(
synced_timestamps: SynchronizedTimestamps,
*,
subsample: int = 1,
parallel: bool = True,
on_progress: Callable[[int, int], None] | None = None,
on_frame_data: Callable[[int, dict[int, FrameData]], None] | None = None,
token: CancellationToken | None = None,
Expand All @@ -56,6 +57,9 @@ def process_synchronized_recording(
tracker: Tracker for 2D point extraction (handles per-cam_id state internally)
synced_timestamps: Pre-constructed timestamp alignment object
subsample: Process every Nth sync index (1 = all, 10 = every 10th)
parallel: Process cameras concurrently (True) or serially (False).
Uses ThreadPoolExecutor when True and multiple cameras present.
Set to False as fallback if threading issues are discovered.
on_progress: Called with (current, total) during processing
on_frame_data: Called with (sync_index, {cam_id: FrameData}) for live display
token: Cancellation token for graceful shutdown
Expand All @@ -75,43 +79,79 @@ def process_synchronized_recording(
point_rows: list[dict] = []

try:
for i, sync_index in enumerate(all_sync_indices):
if token is not None and token.is_cancelled:
logger.info("Processing cancelled")
break
use_pool = parallel and len(frame_sources) > 1

frame_data: dict[int, FrameData] = {}
if use_pool:
camera_pool = ThreadPoolExecutor(max_workers=len(frame_sources))
else:
camera_pool = None

for cam_id in synced_timestamps.cam_ids:
frame_index = synced_timestamps.frame_for(sync_index, cam_id)

if frame_index is None:
logger.debug(f"Dropped frame: sync={sync_index}, cam_id={cam_id}")
continue

if cam_id not in frame_sources:
logger.warning(f"cam_id {cam_id} not in cameras dict, skipping")
continue

camera = cameras[cam_id]
frame = frame_sources[cam_id].get_frame(frame_index)

if frame is None:
logger.warning(
f"Failed to read frame: sync={sync_index}, cam_id={cam_id}, frame_index={frame_index}"
)
continue

points = tracker.get_points(frame, cam_id, camera.rotation_count)
frame_data[cam_id] = FrameData(frame, points, frame_index)

frame_time = synced_timestamps.time_for(cam_id, frame_index)
_accumulate_points(point_rows, sync_index, cam_id, frame_index, frame_time, points)
try:
for i, sync_index in enumerate(all_sync_indices):
if token is not None and token.is_cancelled:
logger.info("Processing cancelled")
break

if on_frame_data is not None:
on_frame_data(sync_index, frame_data)
if on_progress is not None:
on_progress(i + 1, total)
frame_data: dict[int, FrameData] = {}

if use_pool and camera_pool is not None:
# --- Parallel path ---
futures: dict[int, Future[tuple[int, FrameData | None, list[dict]]]] = {}
for cam_id in synced_timestamps.cam_ids:
frame_index = synced_timestamps.frame_for(sync_index, cam_id)
if frame_index is None:
continue
if cam_id not in frame_sources:
continue
camera = cameras[cam_id]
frame_time = synced_timestamps.time_for(cam_id, frame_index)
futures[cam_id] = camera_pool.submit(
_process_one_camera,
cam_id,
sync_index,
frame_index,
frame_sources[cam_id],
camera,
tracker,
frame_time,
)

for cam_id, future in futures.items():
_, fd, rows = future.result()
if fd is not None:
frame_data[cam_id] = fd
point_rows.extend(rows)
else:
# --- Serial path (original logic) ---
for cam_id in synced_timestamps.cam_ids:
frame_index = synced_timestamps.frame_for(sync_index, cam_id)
if frame_index is None:
continue
if cam_id not in frame_sources:
continue
camera = cameras[cam_id]
frame = frame_sources[cam_id].read_frame_at(frame_index)
if frame is None:
logger.warning(
f"Failed to read frame: sync={sync_index}, cam_id={cam_id}, frame_index={frame_index}"
)
continue
points = tracker.get_points(frame, cam_id, camera.rotation_count)
frame_data[cam_id] = FrameData(frame, points, frame_index)
frame_time = synced_timestamps.time_for(cam_id, frame_index)
_accumulate_points(point_rows, sync_index, cam_id, frame_index, frame_time, points)

# Threading contract: callbacks are always invoked from this
# thread (the worker thread that owns the sync-index loop),
# never from pool threads. Presenters rely on this guarantee
# for unsynchronized accumulator state.
if on_frame_data is not None:
on_frame_data(sync_index, frame_data)
if on_progress is not None:
on_progress(i + 1, total)
finally:
if camera_pool is not None:
camera_pool.shutdown(wait=False)

finally:
for source in frame_sources.values():
Expand Down Expand Up @@ -224,6 +264,51 @@ def _accumulate_points(
)


def _process_one_camera(
    cam_id: int,
    sync_index: int,
    frame_index: int,
    frame_source: FrameSource,
    camera: CameraData,
    tracker: Tracker,
    frame_time: float,
) -> tuple[int, FrameData | None, list[dict]]:
    """Read, track, and accumulate 2D points for one camera at one sync index.

    Designed to run on a pool thread, one call per camera. Concurrent calls
    are safe only across *distinct* cam_ids:
      - each FrameSource instance is dedicated to a single camera, so frame
        reads never race;
      - tracker.get_points() keeps its state keyed by cam_id (e.g. the ONNX
        tracker's per-camera bbox cache), so each thread touches only its own
        entries — two threads must never process the same cam_id at once;
      - the row list is built locally and returned, never appended to a
        shared accumulator.

    Returns:
        (cam_id, frame_data_or_none, point_rows). frame_data is None when the
        frame could not be read; a warning is logged and no rows are produced.
    """
    frame = frame_source.read_frame_at(frame_index)

    if frame is None:
        logger.warning(f"Failed to read frame: sync={sync_index}, cam_id={cam_id}, frame_index={frame_index}")
        return cam_id, None, []

    detected = tracker.get_points(frame, cam_id, camera.rotation_count)

    rows: list[dict] = []
    _accumulate_points(rows, sync_index, cam_id, frame_index, frame_time, detected)

    return cam_id, FrameData(frame, detected, frame_index), rows


def _build_image_points(point_rows: list[dict]) -> ImagePoints:
"""Construct ImagePoints from accumulated point data."""
if not point_rows:
Expand Down
94 changes: 86 additions & 8 deletions src/caliscope/gui/presenters/multi_camera_processing_presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ class MultiCameraProcessingPresenter(QObject):
state_changed = Signal(MultiCameraProcessingState)

# Progress signals
progress_updated = Signal(int, int, int) # (current, total, percent)
progress_updated = Signal(int, int, int, str) # (current, total, percent, eta_string)
thumbnail_updated = Signal(int, object, object) # (cam_id, NDArray frame, PointPacket | None)
coverage_updated = Signal(object) # (NDArray[np.int64] matrix, list[int] cam_ids)

# Completion signals
processing_complete = Signal(object, object, object) # (ImagePoints, ExtrinsicCoverageReport, Tracker)
Expand All @@ -89,6 +90,7 @@ class MultiCameraProcessingPresenter(QObject):

# Thumbnail throttle interval (seconds)
THUMBNAIL_INTERVAL = 0.1 # ~10 FPS
COVERAGE_INTERVAL = 5.0 # seconds between coverage heatmap updates

def __init__(
self,
Expand Down Expand Up @@ -121,6 +123,17 @@ def __init__(
self._thumbnails: dict[int, NDArray[np.uint8]] = {}
self._last_thumbnail_time: float = 0.0

# ETA and coverage state
self._processing_start_time: float = 0.0
self._last_coverage_time: float = 0.0
self._sync_frames_done: int = 0
self._frames_total: int = 0

# Incremental coverage matrix — updated per sync index, never rebuilt
self._coverage_matrix: np.ndarray | None = None
self._coverage_cam_ids: list[int] = []
self._coverage_cam_id_to_index: dict[int, int] = {}

# -------------------------------------------------------------------------
# Public Properties
# -------------------------------------------------------------------------
Expand Down Expand Up @@ -297,6 +310,18 @@ def start_processing(self, subsample: int = 1) -> None:
# Clear previous results only after successful timestamp construction
self._reset_results()

self._processing_start_time = time.time()
self._last_coverage_time = 0.0
self._sync_frames_done = 0
self._frames_total = 0

# Initialize incremental coverage matrix
cam_ids = sorted(cameras.keys())
self._coverage_cam_ids = cam_ids
self._coverage_cam_id_to_index = {cid: idx for idx, cid in enumerate(cam_ids)}
n = len(cam_ids)
self._coverage_matrix = np.zeros((n, n), dtype=np.int64)

def worker(token: CancellationToken, handle: TaskHandle) -> ImagePoints:
return process_synchronized_recording(
recording_dir=recording_dir,
Expand Down Expand Up @@ -383,23 +408,71 @@ def cleanup(self) -> None:
# -------------------------------------------------------------------------

def _on_progress(self, current: int, total: int) -> None:
"""Progress callback from process_synchronized_recording."""
"""Progress callback from process_synchronized_recording.

Called from worker thread. Thread-safe because progress_updated
is a Qt signal (cross-thread emission handled by Qt event loop).
"""
self._frames_total = total
percent = int(100 * current / total) if total > 0 else 0
self.progress_updated.emit(current, total, percent)

# ETA computation (wait 3 seconds for rate to stabilize)
elapsed = time.time() - self._processing_start_time
eta_str = ""
if elapsed > 3.0 and current > 0:
rate = current / elapsed
remaining = max(0.0, (total - current) / rate)
minutes, seconds = divmod(int(remaining), 60)
eta_str = f" — ~{minutes}:{seconds:02d} remaining"

self.progress_updated.emit(current, total, percent, eta_str)

def _on_frame_data(self, sync_index: int, frame_data: dict[int, FrameData]) -> None:
    """Frame data callback from process_synchronized_recording.

    Called from the worker thread. Accumulator state (_sync_frames_done,
    _coverage_matrix, _thumbnails, throttle timestamps) is not protected by
    locks because this callback is guaranteed single-threaded by the
    processing function's contract (callbacks are invoked from the main
    worker thread, never from pool threads).

    Work performed per call:
      - incremental coverage matrix update (every call, cheap)
      - coverage signal emission (throttled to COVERAGE_INTERVAL)
      - thumbnail emission (throttled to THUMBNAIL_INTERVAL, ~10 FPS)
    """
    now = time.time()
    self._sync_frames_done += 1

    # --- Incremental coverage matrix update (cheap, every call) ---
    # Matrix is None until start_processing() initializes it.
    if self._coverage_matrix is not None:
        # Collect which cameras saw each point_id at this sync_index.
        point_cameras: dict[int, list[int]] = {}
        for cam_id, data in frame_data.items():
            if data.points is not None and len(data.points.point_id) > 0:
                for pid in data.points.point_id:
                    # int(pid) — point_id may be a numpy integer; normalize the key.
                    point_cameras.setdefault(int(pid), []).append(cam_id)

        # Increment pairwise counts: every pair of cameras that co-observed a
        # point gets +1 (diagonal counts single-camera observations once).
        for pid, cam_list in point_cameras.items():
            cam_list_sorted = sorted(cam_list)
            for i, cam_id_i in enumerate(cam_list_sorted):
                for cam_id_j in cam_list_sorted[i:]:
                    if cam_id_i in self._coverage_cam_id_to_index and cam_id_j in self._coverage_cam_id_to_index:
                        idx_i = self._coverage_cam_id_to_index[cam_id_i]
                        idx_j = self._coverage_cam_id_to_index[cam_id_j]
                        self._coverage_matrix[idx_i, idx_j] += 1
                        # Mirror off-diagonal entries to keep the matrix symmetric.
                        if idx_i != idx_j:
                            self._coverage_matrix[idx_j, idx_i] += 1

    # --- Emit coverage snapshot (throttled) ---
    # Copy the matrix so the GUI thread never reads a buffer we keep mutating.
    if now - self._last_coverage_time >= self.COVERAGE_INTERVAL:
        self._last_coverage_time = now
        if self._coverage_matrix is not None:
            self.coverage_updated.emit((self._coverage_matrix.copy(), list(self._coverage_cam_ids)))

    # --- Thumbnails (throttled, ~10 FPS) ---
    if now - self._last_thumbnail_time < self.THUMBNAIL_INTERVAL:
        return

    self._last_thumbnail_time = now

    # Update thumbnails for all cameras in this sync packet; points are
    # passed to the View for overlay rendering (MVP: View renders).
    for cam_id, data in frame_data.items():
        self._thumbnails[cam_id] = data.frame
        self.thumbnail_updated.emit(cam_id, data.frame, data.points)
Expand Down Expand Up @@ -438,6 +511,11 @@ def _reset_results(self) -> None:
self._result = None
self._coverage_report = None
self._task_handle = None
self._sync_frames_done = 0
self._frames_total = 0
self._coverage_matrix = None
self._coverage_cam_ids = []
self._coverage_cam_id_to_index = {}

def _emit_state_changed(self) -> None:
"""Emit state_changed signal with current computed state."""
Expand Down
20 changes: 15 additions & 5 deletions src/caliscope/gui/views/multi_camera_processing_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ def _connect_signals(self) -> None:
self._presenter.thumbnail_updated.connect(self._on_thumbnail_updated)
self._presenter.processing_complete.connect(self._on_processing_complete)
self._presenter.processing_failed.connect(self._on_processing_failed)
self._presenter.coverage_updated.connect(
self._on_coverage_updated,
Qt.ConnectionType.QueuedConnection,
)

def _update_ui_for_state(self, state: "MultiCameraProcessingState") -> None:
"""Update UI elements based on presenter state."""
Expand Down Expand Up @@ -325,9 +329,9 @@ def _update_ui_for_state(self, state: "MultiCameraProcessingState") -> None:
self._progress_label.setText("Starting...")
self._set_rotation_enabled(False)
self._subsample_spin.setEnabled(False)
# Keep showing placeholder during processing
self._coverage_placeholder.show()
self._coverage_content.hide()
# Show coverage content for live heatmap updates
self._coverage_placeholder.hide()
self._coverage_content.show()

elif state == MultiCameraProcessingState.COMPLETE:
self._action_btn.setText("Reset")
Expand All @@ -349,10 +353,16 @@ def _set_rotation_enabled(self, enabled: bool) -> None:
# Slots for Presenter Signals
# -------------------------------------------------------------------------

def _on_progress_updated(self, current: int, total: int, percent: int) -> None:
def _on_progress_updated(self, current: int, total: int, percent: int, eta_str: str) -> None:
"""Handle progress update from presenter."""
self._progress_bar.setValue(percent)
self._progress_label.setText(f"Processing: {current}/{total} frames ({percent}%)")
self._progress_label.setText(f"Processing: {current}/{total} frames ({percent}%){eta_str}")

def _on_coverage_updated(self, data: tuple) -> None:
"""Update the coverage heatmap during processing."""
matrix, cam_ids = data
labels = [f"C{cid}" for cid in cam_ids]
self._coverage_heatmap.set_data(matrix, killed_linkages=set(), labels=labels)

def _on_thumbnail_updated(self, cam_id: int, frame: NDArray, points: "PointPacket | None") -> None:
"""Handle thumbnail update from presenter.
Expand Down
Loading
Loading