def datasets_root() -> Path:
    """Return the directory under which datasets are stored.

    The manifest located by ``get_manifest_path()`` is consulted first: when
    it holds a non-empty ``datasets.root`` entry, that path is returned with
    ``~`` expanded. In every other case the default
    ``<roboclaw home>/workspace/embodied/datasets`` is returned.

    Raises:
        json.JSONDecodeError: If the manifest exists but is not valid JSON.
    """
    # Imported lazily, as in the original; presumably to avoid an import
    # cycle with the manifest helpers -- TODO confirm.
    from roboclaw.embodied.embodiment.manifest.helpers import get_manifest_path, get_roboclaw_home

    try:
        # EAFP: reading directly avoids the exists()/read_text() race.
        data = json.loads(get_manifest_path().read_text(encoding="utf-8"))
    except FileNotFoundError:
        data = {}

    # ``"datasets": null`` (or a non-object manifest) must fall through to the
    # default instead of raising AttributeError on the chained ``.get``.
    datasets = data.get("datasets") if isinstance(data, dict) else None
    root = datasets.get("root") if isinstance(datasets, dict) else None
    if root:
        return Path(root).expanduser()
    return get_roboclaw_home() / "workspace" / "embodied" / "datasets"
b/roboclaw/data/repair/__init__.py new file mode 100644 index 00000000..757d75ae --- /dev/null +++ b/roboclaw/data/repair/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from .types import DamageType, DiagnosisResult, RepairResult + +__all__ = ["DamageType", "DiagnosisResult", "RepairResult"] diff --git a/roboclaw/data/repair/boundary.py b/roboclaw/data/repair/boundary.py new file mode 100644 index 00000000..c47082eb --- /dev/null +++ b/roboclaw/data/repair/boundary.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +import csv +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from PIL import Image + +from roboclaw.data.paths import datasets_root + +from .io import load_info +from .lerobot_adapter import LeRobotDatasetAdapter + + +def resolve_dataset_root(dataset: str, root: Path | None) -> Path: + dataset_path = Path(dataset).expanduser() + if dataset_path.exists(): + return dataset_path.resolve() + + base_root = root.expanduser().resolve() if root is not None else datasets_root().resolve() + candidates = [base_root / dataset] + if "/" not in dataset.strip("/"): + candidates.append(base_root / "local" / dataset) + for candidate in candidates: + if candidate.exists(): + return candidate.resolve() + searched = "\n".join(f"- {candidate}" for candidate in candidates) + raise FileNotFoundError(f"Dataset path not found. 
def parse_episode_indices(spec: str, total_episodes: int) -> list[int]:
    """Expand an episode selector string into a list of episode indices.

    ``spec`` is either ``"all"`` (case-insensitive, surrounding whitespace
    ignored) or a comma-separated list of single indices (``"4"``) and
    inclusive ranges (``"2-5"``). Empty items are skipped. The result keeps
    the order in which indices first appear and contains no duplicates.

    Args:
        spec: Selector string supplied by the caller.
        total_episodes: Number of episodes in the dataset; valid indices are
            ``0 .. total_episodes - 1``.

    Returns:
        The selected episode indices.

    Raises:
        ValueError: If the dataset has no episodes, an item is malformed, a
            range is reversed, an index is out of range, or nothing was
            selected.
    """
    if total_episodes <= 0:
        raise ValueError("Dataset has no episodes to export.")
    if spec.strip().lower() == "all":
        return list(range(total_episodes))

    episode_indices: list[int] = []
    seen: set[int] = set()
    for chunk in spec.split(","):
        part = chunk.strip()
        if not part:
            continue
        # Split on the first dash only, so "1-2-3" is rejected rather than
        # silently reinterpreted.
        start_text, dash, end_text = part.partition("-")
        try:
            start = int(start_text)
            end = int(end_text) if dash else start
        except ValueError as exc:
            # Name the offending item; the bare int() message ("invalid
            # literal ... ''") gives the user nothing to act on.
            raise ValueError(
                f"Invalid episode selector: {part!r} (expected N or START-END)"
            ) from exc
        if end < start:
            raise ValueError(f"Invalid episode range: {part}")
        for episode_index in range(start, end + 1):
            if episode_index < 0 or episode_index >= total_episodes:
                raise ValueError(
                    f"Episode index {episode_index} is out of range for dataset with {total_episodes} episodes."
                )
            if episode_index not in seen:
                episode_indices.append(episode_index)
                seen.add(episode_index)
    if not episode_indices:
        raise ValueError("No episodes were selected.")
    return episode_indices
def frame_to_pil_image(frame: Any) -> Image.Image:
    """Convert a dataset frame into a PIL image ready to be saved.

    Accepts a PIL image (returned as a copy), a ``torch.Tensor`` or anything
    ``numpy.asarray`` understands.

    Conversion steps for array-like input:
      * a 3-D array whose first axis has length 1, 3 or 4 is treated as
        channels-first and transposed to channels-last;
      * a trailing single-channel axis is dropped, because
        ``Image.fromarray`` expects 2-D data for greyscale images;
      * non-``uint8`` data is converted to ``uint8``. Floating-point data
        whose maximum is at most 1.0 is assumed to be normalised to
        ``[0, 1]`` and is scaled by 255 first; everything else is clipped to
        ``[0, 255]`` as-is.

    Args:
        frame: The frame to convert.

    Returns:
        A new ``PIL.Image.Image``.
    """
    if isinstance(frame, Image.Image):
        return frame.copy()
    if isinstance(frame, torch.Tensor):
        array = frame.detach().cpu().numpy()
    else:
        array = np.asarray(frame)

    if array.ndim == 3 and array.shape[0] in {1, 3, 4}:
        array = np.moveaxis(array, 0, -1)
    if array.ndim == 3 and array.shape[-1] == 1:
        # (H, W, 1) is rejected by Image.fromarray; greyscale must be (H, W).
        array = array[..., 0]
    if array.dtype != np.uint8:
        is_unit_float = (
            np.issubdtype(array.dtype, np.floating)
            and array.size > 0
            and float(array.max()) <= 1.0
        )
        if is_unit_float:
            # Normalised frames (e.g. float tensors in [0, 1]) would otherwise
            # be truncated to 0/1 and saved as an almost black image.
            array = np.rint(array * 255.0)
        array = np.clip(array, 0, 255).astype(np.uint8)
    return Image.fromarray(array)
def is_repairable(damage_type: DamageType, details: dict[str, Any]) -> bool:
    """Tell whether enough material is left to repair the diagnosed damage.

    Each damage type requires specific counters in ``details`` to be
    positive:

      * ``CRASH_NO_SAVE``: ``n_recovery_lines`` and ``min_images_per_camera``
      * ``TMP_VIDEOS_STUCK``: ``n_recovery_lines`` and ``n_tmp_videos``
      * ``PARQUET_NO_VIDEO``: ``n_parquet_rows`` and ``min_images_per_camera``
      * ``META_STALE``: ``n_parquet_rows``
      * ``FRAME_MISMATCH``: ``truncate_target_frames``
      * ``MISSING_CP``: ``n_log_cp``

    Any other damage type is reported as not repairable.

    Args:
        damage_type: The classified damage.
        details: Counter dictionary produced by the diagnosis step.

    Returns:
        ``True`` when all counters required for ``damage_type`` are positive.
    """

    def has_any(key: str) -> bool:
        # Every counter is read the same way: a missing or null entry counts
        # as zero, so a partial ``details`` dict yields "not repairable"
        # instead of a KeyError for some damage types only.
        return (details.get(key) or 0) > 0

    if damage_type == DamageType.CRASH_NO_SAVE:
        return has_any("n_recovery_lines") and has_any("min_images_per_camera")
    if damage_type == DamageType.TMP_VIDEOS_STUCK:
        return has_any("n_recovery_lines") and has_any("n_tmp_videos")
    if damage_type == DamageType.PARQUET_NO_VIDEO:
        return has_any("n_parquet_rows") and has_any("min_images_per_camera")
    if damage_type == DamageType.META_STALE:
        return has_any("n_parquet_rows")
    if damage_type == DamageType.FRAME_MISMATCH:
        return has_any("truncate_target_frames")
    if damage_type == DamageType.MISSING_CP:
        return has_any("n_log_cp")
    return False
DamageType.EMPTY_SHELL + if n_video_files == 0 and tmp_videos and n_recovery_lines > 0: + return DamageType.TMP_VIDEOS_STUCK + if n_parquet_rows == 0 and n_video_files == 0 and (n_recovery_lines > 0 or image_floor > 0): + return DamageType.CRASH_NO_SAVE + if n_parquet_rows > 0 and n_video_files == 0 and video_keys: + return DamageType.PARQUET_NO_VIDEO + if total_episodes == 0 and n_parquet_rows > 0 and n_video_files > 0: + return DamageType.META_STALE + if has_frame_mismatch(n_recovery_lines, images_per_camera): + return DamageType.FRAME_MISMATCH + if not has_cp and log_cp_intervals and n_parquet_rows > 0: + return DamageType.MISSING_CP + return DamageType.HEALTHY + + +class DatasetDiagnosisService: + def diagnose(self, dataset_dir: Path) -> DiagnosisResult: + info = load_info(dataset_dir) + total_episodes = int(info.get("total_episodes", 0)) + total_frames = int(info.get("total_frames", 0)) + n_recovery_lines = len(read_recovery_rows(dataset_dir)) + images_per_camera = count_images_per_camera(dataset_dir) + image_floor = min_images_per_camera(images_per_camera) + n_parquet_files, _episode_count, n_parquet_rows = scan_parquet_files(dataset_dir) + n_video_files = count_video_files(dataset_dir) + video_keys = get_video_keys(info) + tmp_videos = find_tmp_videos(dataset_dir) + has_cp = (dataset_dir / "critical_phase_intervals.json").exists() + log_path = find_log_for_dataset(dataset_dir) + log_cp_intervals = parse_cp_from_log(log_path) if log_path and not has_cp else [] + + details: dict[str, Any] = { + "info_total_episodes": total_episodes, + "info_total_frames": total_frames, + "n_recovery_lines": n_recovery_lines, + "images_per_camera": images_per_camera, + "min_images_per_camera": image_floor, + "n_parquet_files": n_parquet_files, + "n_parquet_rows": n_parquet_rows, + "n_video_files": n_video_files, + "n_video_keys": len(video_keys), + "n_tmp_videos": len(tmp_videos), + "tmp_videos": tmp_videos, + "truncate_target_frames": truncate_target_frames( + 
n_recovery_lines=n_recovery_lines, + image_floor=image_floor, + n_parquet_rows=n_parquet_rows, + ), + "has_cp": has_cp, + "n_log_cp": len(log_cp_intervals), + "log_cp_intervals": log_cp_intervals, + "log_path": log_path, + } + + damage_type = _classify_damage( + total_episodes=total_episodes, + n_recovery_lines=n_recovery_lines, + image_floor=image_floor, + n_parquet_rows=n_parquet_rows, + n_video_files=n_video_files, + video_keys=video_keys, + tmp_videos=tmp_videos, + images_per_camera=images_per_camera, + has_cp=has_cp, + log_cp_intervals=log_cp_intervals, + ) + return DiagnosisResult( + dataset_dir=dataset_dir, + damage_type=damage_type, + repairable=is_repairable(damage_type, details), + details=details, + ) + + def verify(self, dataset_dir: Path) -> list[str]: + errors: list[str] = [] + info_path = dataset_dir / "meta" / "info.json" + if not info_path.exists(): + return ["info.json missing"] + + try: + info = load_info(dataset_dir) + except (json.JSONDecodeError, OSError) as exc: + log.exception("Unable to read %s", info_path) + return [f"info.json unreadable: {exc}"] + + total_episodes = int(info.get("total_episodes", 0)) + total_frames = int(info.get("total_frames", 0)) + if total_episodes <= 0: + errors.append(f"total_episodes={total_episodes} (expected > 0)") + if total_frames <= 0: + errors.append(f"total_frames={total_frames} (expected > 0)") + + parquet_rows = 0 + for parquet_path in sorted((dataset_dir / "data").rglob("*.parquet")): + metadata = safe_read_parquet_metadata(parquet_path) + table = safe_read_parquet_table(parquet_path) + if metadata is None or table is None: + errors.append(f"unreadable parquet: {parquet_path.relative_to(dataset_dir)}") + continue + parquet_rows += metadata.num_rows + + if parquet_rows != total_frames: + errors.append(f"parquet row sum {parquet_rows} != info total_frames {total_frames}") + + for video_key in get_video_keys(info): + for episode_index in range(total_episodes): + video_path = build_video_path(dataset_dir, 
def parse_cp_from_log(log_path: Path) -> list[dict[str, Any]]:
    """Extract critical-phase intervals from a recording log.

    Every line matching ``_CP_END_RE`` (a ``[CP] END at episode ...`` entry)
    contributes one interval, in file order.

    Args:
        log_path: Path of the log file to scan.

    Returns:
        A list of dicts with ``episode_index``, ``start_frame`` and
        ``end_frame`` as ints, and ``outcome`` as the logged outcome string
        or ``None`` when the line carries none.

    Raises:
        OSError: If the log cannot be opened.
    """
    intervals: list[dict[str, Any]] = []
    # Explicit UTF-8 keeps the result independent of the platform locale.
    # ``errors="replace"`` keeps undecodable bytes (e.g. a log cut off
    # mid-character by a crash) from aborting the scan with
    # UnicodeDecodeError.
    with log_path.open(encoding="utf-8", errors="replace") as handle:
        for line in handle:
            match = _CP_END_RE.search(line)
            if match is None:
                continue
            # Group 2 (the frame at which END was logged) is not stored; the
            # segment bounds in groups 3 and 4 are.
            intervals.append(
                {
                    "episode_index": int(match.group(1)),
                    "start_frame": int(match.group(3)),
                    "end_frame": int(match.group(4)),
                    "outcome": match.group(5),
                }
            )
    return intervals
def read_recovery_rows(dataset_dir: Path) -> list[dict[str, Any]]:
    """Read the usable prefix of ``recovery_frames.jsonl``.

    Lines are sanitised with ``sanitize_jsonl_line`` and parsed one by one.
    Reading stops at the first line that is empty after sanitising, is not
    valid JSON, or is not a JSON object; everything before it is returned.

    Args:
        dataset_dir: Dataset directory expected to contain the recovery file.

    Returns:
        The parsed rows, or an empty list when the file does not exist.
    """
    recovery_path = dataset_dir / "recovery_frames.jsonl"
    if not recovery_path.exists():
        return []

    rows: list[dict[str, Any]] = []
    # The file may have been cut off mid-write. ``errors="replace"`` turns an
    # undecodable tail into text that fails JSON parsing below, so the good
    # prefix is still returned instead of raising UnicodeDecodeError.
    with recovery_path.open(encoding="utf-8", errors="replace") as handle:
        for line_number, line in enumerate(handle, start=1):
            sanitized = sanitize_jsonl_line(line)
            if not sanitized:
                break
            try:
                row = json.loads(sanitized)
            except json.JSONDecodeError as exc:
                log.warning("Corrupt JSON in %s at line %d: %s", recovery_path, line_number, exc)
                break
            if not isinstance(row, dict):
                # Callers index rows by feature key; a bare number or list is
                # as unusable as a syntax error.
                log.warning(
                    "Corrupt JSON in %s at line %d: %s",
                    recovery_path,
                    line_number,
                    f"expected an object, got {type(row).__name__}",
                )
                break
            rows.append(row)
    return rows
def build_video_path(
    dataset_dir: Path,
    info: dict[str, Any],
    video_key: str,
    episode_index: int,
) -> Path:
    """Build the path of the video file for one episode and camera.

    The path template is ``info["video_path"]`` or, when that is missing or
    empty, ``DEFAULT_VIDEO_PATH``. The episode index is split into
    ``chunk_index = episode_index // chunks_size`` and
    ``file_index = episode_index % chunks_size``.

    The template may use any of the placeholders ``{video_key}``,
    ``{chunk_index}``, ``{file_index}``, ``{episode_chunk}`` and
    ``{episode_index}``. The last two are supplied so that templates written
    in the older per-episode layout format cleanly instead of raising
    KeyError.

    Args:
        dataset_dir: Root directory of the dataset.
        info: Parsed ``meta/info.json``.
        video_key: Feature key of the camera stream.
        episode_index: Index of the episode.

    Returns:
        ``dataset_dir`` joined with the formatted template.
    """
    # A missing, null or zero chunks_size falls back to the default rather
    # than raising TypeError or ZeroDivisionError on damaged metadata.
    chunks_size = int(info.get("chunks_size") or 1000)
    template = info.get("video_path") or DEFAULT_VIDEO_PATH
    chunk_index, file_index = divmod(episode_index, chunks_size)
    relative_path = template.format(
        video_key=video_key,
        chunk_index=chunk_index,
        file_index=file_index,
        # Aliases for the legacy placeholders; str.format ignores unused
        # keyword arguments.
        episode_chunk=chunk_index,
        episode_index=episode_index,
    )
    return dataset_dir / relative_path
def prepare_output_dir(dataset_dir: Path, force: bool) -> tuple[Path, bool]:
    """Work out where a rebuilt copy of ``dataset_dir`` goes and whether to build it.

    The target is the sibling directory ``<name>_repaired``. It is never
    created here; an existing one is deleted only when ``force`` is set.

    Args:
        dataset_dir: The damaged dataset directory.
        force: Delete an existing target so it can be rebuilt.

    Returns:
        ``(target, proceed)``. ``proceed`` is ``False`` only when the target
        already exists and ``force`` is not set.
    """
    target = dataset_dir.parent / f"{dataset_dir.name}_repaired"
    if not target.exists():
        return target, True
    if not force:
        # Keep the earlier repair output untouched.
        return target, False
    shutil.rmtree(target)
    return target, True
def patch_episodes_video_columns(dataset_dir: Path, video_keys: list[str], n_frames: int, fps: int) -> None:
    """Add per-video bookkeeping columns to the episode metadata parquet files.

    For every parquet file under ``meta/episodes`` and every key in
    ``video_keys``, the columns ``videos/<key>/chunk_index``,
    ``videos/<key>/file_index``, ``videos/<key>/from_timestamp`` and
    ``videos/<key>/to_timestamp`` are appended when absent. Chunk index, file
    index and start timestamp are filled with zero; the end timestamp is
    ``(n_frames - 1) / fps`` (``0.0`` when ``fps`` is not positive). Every
    row of a file receives the same values.

    Columns that already exist are left untouched, and a file is rewritten
    only when at least one column was added.

    Args:
        dataset_dir: Root directory of the dataset to patch.
        video_keys: Feature keys of the video streams.
        n_frames: Number of frames in the video.
        fps: Frame rate used to derive the end timestamp.
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    # Loop-invariant; clamped so an empty recording cannot yield a negative
    # end timestamp.
    to_timestamp = max(n_frames - 1, 0) / fps if fps > 0 else 0.0
    column_specs = (
        ("chunk_index", 0, pa.int64()),
        ("file_index", 0, pa.int64()),
        ("from_timestamp", 0.0, pa.float64()),
        ("to_timestamp", to_timestamp, pa.float64()),
    )

    for episode_path in sorted((dataset_dir / "meta" / "episodes").rglob("*.parquet")):
        table = pq.read_table(episode_path)
        n_rows = len(table)
        changed = False
        for video_key in video_keys:
            for suffix, value, arrow_type in column_specs:
                column_name = f"videos/{video_key}/{suffix}"
                # Check each column on its own: appending one that already
                # exists would create a duplicate column name.
                if column_name in table.column_names:
                    continue
                table = table.append_column(column_name, pa.array([value] * n_rows, type=arrow_type))
                changed = True
        if changed:
            # Skip the rewrite for files that needed no patching.
            pq.write_table(table, episode_path)
class DatasetRepairService:
    """Repair a damaged dataset directory according to its diagnosis.

    ``repair`` is the public entry point: it filters out the cases that need
    no work, dispatches to the handler matching the diagnosed ``DamageType``
    and then verifies the result. Dataset creation and video encoding go
    through the injected adapter.
    """

    def __init__(self, adapter: LeRobotDatasetAdapter | None = None) -> None:
        """Store ``adapter``; build a default ``LeRobotDatasetAdapter`` when none is given."""
        self._adapter = adapter or LeRobotDatasetAdapter()

    def repair(
        self,
        diagnosis: DiagnosisResult,
        *,
        task: str,
        vcodec: str,
        dry_run: bool,
        force: bool,
    ) -> RepairResult:
        """Repair the dataset described by ``diagnosis`` and verify the outcome.

        Args:
            diagnosis: Diagnosis whose ``damage_type`` selects the strategy.
            task: Task value passed to ``build_frame_dict`` for rebuilt frames.
            vcodec: Video codec handed to the adapter when videos are encoded.
            dry_run: When true, nothing is modified and the result is "skipped".
            force: Forwarded to ``prepare_output_dir`` by the rebuilding strategies.

        Returns:
            A ``RepairResult`` whose outcome is "healthy", "skipped",
            "repaired", or "failed" when verification reports errors.
        """
        dataset_dir = diagnosis.dataset_dir
        damage = diagnosis.damage_type

        # Cases that never touch the dataset.
        if damage == DamageType.HEALTHY:
            return RepairResult(dataset_dir, damage, "healthy")
        if damage == DamageType.EMPTY_SHELL:
            return RepairResult(dataset_dir, damage, "skipped", error="empty shell -- nothing to recover")
        if not diagnosis.repairable:
            return RepairResult(dataset_dir, damage, "skipped", error="unrepairable")
        if dry_run:
            return RepairResult(dataset_dir, damage, "skipped", error="dry run")

        result = self._dispatch_repair(diagnosis, task=task, vcodec=vcodec, force=force)
        if result.outcome != "repaired":
            return result

        # The rebuilding strategies write to a sibling "<name>_repaired"
        # directory; every other strategy patches the dataset in place.
        # NOTE(review): this mirrors the directory name presumably chosen by
        # prepare_output_dir -- confirm the two stay in sync.
        verify_dir = dataset_dir
        if damage in {DamageType.CRASH_NO_SAVE, DamageType.TMP_VIDEOS_STUCK}:
            verify_dir = dataset_dir.parent / f"{dataset_dir.name}_repaired"
        verify_errors = verify_repaired_dataset(verify_dir)
        if not verify_errors:
            return result
        return RepairResult(dataset_dir, damage, "failed", error="; ".join(verify_errors))

    def _dispatch_repair(
        self,
        diagnosis: DiagnosisResult,
        *,
        task: str,
        vcodec: str,
        force: bool,
    ) -> RepairResult:
        """Run the handler for ``diagnosis.damage_type``.

        Returns:
            "skipped" when a rebuilding strategy did not proceed, otherwise
            "repaired". Exceptions raised by the handlers propagate.
        """
        dataset_dir = diagnosis.dataset_dir
        damage = diagnosis.damage_type
        if damage == DamageType.CRASH_NO_SAVE:
            _, rebuilt = self._repair_crash_no_save(dataset_dir, diagnosis, task=task, vcodec=vcodec, force=force)
            if not rebuilt:
                return RepairResult(dataset_dir, damage, "skipped", error="_repaired already exists")
        elif damage == DamageType.TMP_VIDEOS_STUCK:
            _, rebuilt = self._repair_tmp_videos_stuck(dataset_dir, diagnosis, task=task, force=force)
            if not rebuilt:
                return RepairResult(dataset_dir, damage, "skipped", error="_repaired already exists")
        elif damage == DamageType.PARQUET_NO_VIDEO:
            self._repair_parquet_no_video(dataset_dir, vcodec=vcodec)
        elif damage == DamageType.META_STALE:
            self._repair_meta_stale(dataset_dir)
        elif damage == DamageType.FRAME_MISMATCH:
            self._repair_frame_mismatch(dataset_dir, diagnosis)
        elif damage == DamageType.MISSING_CP:
            self._repair_missing_cp(dataset_dir, diagnosis)
        # NOTE(review): a damage type without a branch above also falls
        # through to "repaired"; ``repair`` filters HEALTHY and EMPTY_SHELL
        # before calling this method.
        return RepairResult(dataset_dir, damage, "repaired")

    def _repair_crash_no_save(
        self,
        dataset_dir: Path,
        diagnosis: DiagnosisResult,
        *,
        task: str,
        vcodec: str,
        force: bool,
    ) -> tuple[Path, bool]:
        """Rebuild a dataset in the output directory from recovery rows and images.

        Returns:
            ``(out_dir, rebuilt)``; ``rebuilt`` is false when
            ``prepare_output_dir`` reported that the repair should not proceed.

        Raises:
            ValueError: If there is no usable frame to rebuild from.
        """
        out_dir, should_proceed = prepare_output_dir(dataset_dir, force)
        if not should_proceed:
            return out_dir, False

        info = load_info(dataset_dir)
        recovery_rows = read_recovery_rows(dataset_dir)
        features = normalize_feature_shapes(info["features"])
        image_keys = get_visual_keys(info)
        images_dir = dataset_dir / "images"
        # Never use more rows than the diagnosed per-camera image minimum.
        n_usable = min(len(recovery_rows), diagnosis.details["min_images_per_camera"])
        if n_usable <= 0:
            raise ValueError(f"No usable frames available to rebuild {dataset_dir}")

        dataset = self._adapter.create_dataset(
            repo_id=f"local/{out_dir.name}",
            fps=int(info["fps"]),
            root=out_dir,
            robot_type=info.get("robot_type"),
            features=features,
            use_videos=bool(image_keys),
            vcodec=vcodec,
        )
        episode_name_by_key = {key: get_single_episode_name(images_dir, key) for key in image_keys}
        actual_frames = add_frames_from_recovery(
            dataset=dataset,
            recovery_rows=recovery_rows[:n_usable],
            features=features,
            images_dir=images_dir,
            image_keys=image_keys,
            episode_name_by_key=episode_name_by_key,
            task=task,
        )
        dataset.save_episode()
        dataset.finalize()
        # Limit the copied intervals to the frames that were actually added.
        copy_critical_phase_intervals(dataset_dir, out_dir, max_frames=actual_frames)
        return out_dir, True

    def _repair_tmp_videos_stuck(
        self,
        dataset_dir: Path,
        diagnosis: DiagnosisResult,
        *,
        task: str,
        force: bool,
    ) -> tuple[Path, bool]:
        """Rebuild non-visual data from recovery rows and reuse the stuck MP4 files.

        The output dataset is first created without any video/image feature,
        then each file in ``diagnosis.details["tmp_videos"]`` is copied to the
        episode-0 video path and the metadata is patched to describe it.

        Returns:
            ``(out_dir, rebuilt)``; ``rebuilt`` is false when
            ``prepare_output_dir`` reported that the repair should not proceed.
        """
        out_dir, should_proceed = prepare_output_dir(dataset_dir, force)
        if not should_proceed:
            return out_dir, False

        info = load_info(dataset_dir)
        recovery_rows = read_recovery_rows(dataset_dir)
        features = normalize_feature_shapes(info["features"])
        video_keys = get_video_keys(info)
        tmp_videos: dict[str, Path] = diagnosis.details["tmp_videos"]
        non_video_features = {
            key: value for key, value in features.items() if value.get("dtype") not in {"video", "image"}
        }

        dataset = self._adapter.create_dataset(
            repo_id=f"local/{out_dir.name}",
            fps=int(info["fps"]),
            root=out_dir,
            robot_type=info.get("robot_type"),
            features=non_video_features,
            use_videos=False,
            vcodec="auto",
        )
        for frame_index, recovery_row in enumerate(recovery_rows):
            dataset.add_frame(
                build_frame_dict(
                    recovery_row=recovery_row,
                    features=non_video_features,
                    images_dir=dataset_dir / "images",
                    image_keys=[],
                    episode_name_by_key={},
                    frame_index=frame_index,
                    task=task,
                )
            )
        dataset.save_episode()
        dataset.finalize()

        # Video keys without a recovered temporary file are skipped.
        for video_key in video_keys:
            if video_key not in tmp_videos:
                continue
            src_mp4 = tmp_videos[video_key]
            dst_mp4 = build_video_path(out_dir, info, video_key, episode_index=0)
            dst_mp4.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src_mp4, dst_mp4)

        # Restore the original feature set (including the video features) and
        # describe the single rebuilt episode.
        out_info = load_info(out_dir)
        out_info["features"] = dict(info["features"])
        out_info["total_episodes"] = 1
        out_info["total_frames"] = len(recovery_rows)
        out_info["video_path"] = info.get("video_path") or DEFAULT_VIDEO_PATH
        write_info(out_dir, out_info)
        patch_episodes_video_columns(out_dir, video_keys, len(recovery_rows), int(info["fps"]))
        copy_critical_phase_intervals(dataset_dir, out_dir)
        return out_dir, True

    def _repair_parquet_no_video(self, dataset_dir: Path, *, vcodec: str) -> None:
        """Encode every episode image directory of every video key into a video.

        Raises:
            FileNotFoundError: If a video key has no episode image directory.
        """
        info = load_info(dataset_dir)
        images_dir = dataset_dir / "images"
        fps = int(info["fps"])
        for video_key in get_video_keys(info):
            episode_dirs = list_episode_dirs(images_dir / video_key)
            if not episode_dirs:
                raise FileNotFoundError(
                    f"No PNG episode directories found for video key {video_key} in {dataset_dir}"
                )
            for episode_dir in episode_dirs:
                self._adapter.encode_video_frames(
                    frames_dir=episode_dir,
                    video_path=build_video_path(dataset_dir, info, video_key, parse_episode_index(episode_dir)),
                    fps=fps,
                    vcodec=vcodec,
                )

    def _patch_info_totals_from_parquet(self, dataset_dir: Path) -> tuple[int, int]:
        """Rewrite the info totals and splits from a scan of the parquet files.

        Returns:
            ``(total_episodes, total_frames)`` as written to the info file.
        """
        info = load_info(dataset_dir)
        _n_files, total_episodes, total_frames = scan_parquet_files(dataset_dir)
        info["total_episodes"] = total_episodes
        info["total_frames"] = total_frames
        info["splits"] = {"train": f"0:{total_episodes}"} if total_episodes > 0 else {}
        write_info(dataset_dir, info)
        return total_episodes, total_frames

    def _repair_meta_stale(self, dataset_dir: Path) -> None:
        """Refresh stale info totals from the parquet data."""
        self._patch_info_totals_from_parquet(dataset_dir)

    def _repair_missing_cp(self, dataset_dir: Path, diagnosis: DiagnosisResult) -> None:
        """Write ``critical_phase_intervals.json`` from ``details["log_cp_intervals"]``."""
        cp_path = dataset_dir / "critical_phase_intervals.json"
        cp_path.write_text(
            json.dumps(diagnosis.details["log_cp_intervals"], indent=2) + "\n",
            encoding="utf-8",
        )

    def _repair_frame_mismatch(self, dataset_dir: Path, diagnosis: DiagnosisResult) -> None:
        """Truncate recovery rows, images and parquet data to a common frame count.

        Raises:
            ValueError: If ``details["truncate_target_frames"]`` is not positive.
        """
        n_keep = diagnosis.details["truncate_target_frames"]
        if n_keep <= 0:
            raise ValueError(f"No positive truncate target for {dataset_dir}")
        self._truncate_recovery_jsonl(dataset_dir, n_keep)
        self._truncate_images(dataset_dir, n_keep)
        self._truncate_parquet(dataset_dir, n_keep)
        # Totals are only refreshed when the diagnosis saw parquet rows.
        if diagnosis.details["n_parquet_rows"] > 0:
            self._patch_info_totals_from_parquet(dataset_dir)

    def _truncate_recovery_jsonl(self, dataset_dir: Path, n_keep: int) -> None:
        """Keep at most the first ``n_keep`` sanitized lines of ``recovery_frames.jsonl``.

        Reading stops at the first line that ``sanitize_jsonl_line`` rejects.
        A missing file is left untouched.
        """
        recovery_path = dataset_dir / "recovery_frames.jsonl"
        if not recovery_path.exists():
            return
        kept_lines: list[str] = []
        # Read with the same encoding the file is rewritten in below, rather
        # than the platform default.
        with recovery_path.open(encoding="utf-8") as handle:
            for line in handle:
                # Check the limit before appending so a non-positive n_keep
                # keeps nothing.
                if len(kept_lines) >= n_keep:
                    break
                sanitized = sanitize_jsonl_line(line)
                if not sanitized:
                    break
                kept_lines.append(f"{sanitized}\n")
        recovery_path.write_text("".join(kept_lines), encoding="utf-8")

    def _truncate_images(self, dataset_dir: Path, n_keep: int) -> None:
        """Delete every PNG beyond the first ``n_keep`` frames of each camera.

        The frame counter runs across all episode directories of a camera.
        """
        for camera_dir in list_episode_dirs(dataset_dir / "images"):
            seen = 0
            for episode_dir in list_episode_dirs(camera_dir):
                for png_path in list_frame_pngs(episode_dir):
                    seen += 1
                    if seen > n_keep:
                        png_path.unlink()

    def _truncate_parquet(self, dataset_dir: Path, n_keep: int) -> None:
        """Trim the parquet files under ``data/`` to ``n_keep`` rows in total.

        Files are visited in sorted path order. Unreadable files and files
        past the row budget are deleted; the file that crosses the budget is
        rewritten with only its leading rows.
        """
        import pyarrow.parquet as pq

        remaining = n_keep
        for parquet_path in sorted((dataset_dir / "data").rglob("*.parquet")):
            metadata = safe_read_parquet_metadata(parquet_path)
            if metadata is None:
                parquet_path.unlink()
                continue
            if remaining <= 0:
                parquet_path.unlink()
                continue
            if metadata.num_rows <= remaining:
                remaining -= metadata.num_rows
                continue
            table = safe_read_parquet_table(parquet_path)
            if table is None:
                parquet_path.unlink()
                continue
            pq.write_table(table.slice(0, remaining), parquet_path)
            remaining = 0
class DamageType(Enum):
    """Damage categories that a dataset diagnosis can report."""

    # No repair needed / nothing to recover.
    HEALTHY = "healthy"
    EMPTY_SHELL = "empty_shell"

    # Categories for which a repair strategy may be attempted.
    CRASH_NO_SAVE = "crash_no_save"
    TMP_VIDEOS_STUCK = "tmp_videos_stuck"
    PARQUET_NO_VIDEO = "parquet_no_video"
    META_STALE = "meta_stale"
    FRAME_MISMATCH = "frame_mismatch"
    MISSING_CP = "missing_cp"


@dataclass(frozen=True)
class DiagnosisResult:
    """Immutable outcome of diagnosing one dataset directory."""

    # Directory of the dataset that was inspected.
    dataset_dir: Path
    # Damage category detected for that directory.
    damage_type: DamageType
    # Whether a repair should be attempted at all.
    repairable: bool
    # Free-form, damage-specific payload produced by the diagnosis.
    details: dict[str, Any]


@dataclass(frozen=True)
class RepairResult:
    """Immutable outcome of one repair attempt."""

    # Directory of the dataset the repair was requested for.
    dataset_dir: Path
    # Damage category that was handled, if any.
    damage_type: DamageType | None
    # Short status string, e.g. "healthy", "skipped", "repaired" or "failed".
    outcome: str
    # Human-readable reason when the outcome is not a plain success.
    error: str | None = None
class TestBoundaryExport:
    """Checks for exporting the first and last frame of each episode."""

    def test_export_episode_boundary_frames_writes_manifest(self, tmp_path: Path) -> None:
        """Exporting two episodes yields two manifest rows and the boundary PNGs."""
        manifest_path = export_episode_boundary_frames(
            dataset=_FakeDataset(),
            output_dir=tmp_path,
            episode_indices=[0, 1],
            camera_key="observation.images.right_front",
        )

        # Read through a context manager so the handle is closed instead of
        # being left to the garbage collector; newline="" is what the csv
        # module expects for file objects.
        with manifest_path.open(encoding="utf-8", newline="") as handle:
            rows = list(csv.DictReader(handle))
        assert len(rows) == 2
        assert rows[0]["episode_success"] == "success"
        assert (tmp_path / "episode_000_first.png").exists()
        assert (tmp_path / "episode_001_last.png").exists()

    def test_exporter_opens_dataset_and_exports(self, tmp_path: Path) -> None:
        """The exporter resolves the dataset, picks the only camera and exports both episodes."""
        dataset_root = tmp_path / "datasets" / "local" / "demo"
        meta_dir = dataset_root / "meta"
        meta_dir.mkdir(parents=True)
        (meta_dir / "info.json").write_text(
            json.dumps(
                {
                    "total_episodes": 2,
                    "features": {
                        "observation.images.right_front": {
                            "dtype": "video",
                            "shape": [4, 4, 3],
                            "names": None,
                        }
                    },
                }
            ),
            encoding="utf-8",
        )

        exporter = BoundaryFrameExporter(adapter=_FakeAdapter())
        result = exporter.export(
            BoundaryFrameExportRequest(
                dataset=str(dataset_root),
                output_dir=tmp_path / "out",
                episodes="0-1",
                overwrite=False,
            )
        )

        assert result.episodes_exported == 2
        assert result.camera_key == "observation.images.right_front"
        assert result.manifest_path.exists()
class TestDatasetRepairCore:
    """In-place repair strategies exercised against small on-disk fixtures."""

    @staticmethod
    def _repair(diagnosis: DiagnosisResult):
        """Run a real (non-dry-run, non-forced) repair with the default service."""
        service = DatasetRepairService()
        return service.repair(
            diagnosis,
            task="task",
            vcodec="h264",
            dry_run=False,
            force=False,
        )

    @staticmethod
    def _read_json(path: Path):
        """Load a UTF-8 JSON document from ``path``."""
        return json.loads(path.read_text(encoding="utf-8"))

    def test_meta_stale_refreshes_info_totals(self, tmp_path: Path) -> None:
        """Stale totals are recomputed from the parquet rows."""
        root = tmp_path / "meta_stale"
        _write_info(root, total_episodes=0, total_frames=0)
        _write_parquet(root, [0, 0, 1])
        for episode in (0, 1):
            _write_video(root, episode)

        outcome = self._repair(
            DiagnosisResult(
                dataset_dir=root,
                damage_type=DamageType.META_STALE,
                repairable=True,
                details={"n_parquet_rows": 3},
            )
        )

        info = self._read_json(root / "meta" / "info.json")
        assert outcome.outcome == "repaired"
        assert info["total_episodes"] == 2
        assert info["total_frames"] == 3
        assert info["splits"] == {"train": "0:2"}

    def test_missing_cp_writes_intervals(self, tmp_path: Path) -> None:
        """Intervals carried by the diagnosis are written next to the data."""
        root = tmp_path / "missing_cp"
        _write_info(root, total_episodes=1, total_frames=1)
        _write_parquet(root, [0])
        _write_video(root)
        logged_interval = {"episode_index": 0, "start_frame": 0, "end_frame": 1, "outcome": "success"}
        details = {
            "n_parquet_rows": 1,
            "log_cp_intervals": [logged_interval],
            "log_path": root.parent / "missing_cp.log",
        }

        outcome = self._repair(
            DiagnosisResult(
                dataset_dir=root,
                damage_type=DamageType.MISSING_CP,
                repairable=True,
                details=details,
            )
        )

        intervals = self._read_json(root / "critical_phase_intervals.json")
        assert outcome.outcome == "repaired"
        assert intervals[0]["outcome"] == "success"

    def test_frame_mismatch_truncates_recovery_images_and_parquet(self, tmp_path: Path) -> None:
        """Recovery rows, images, parquet rows and totals all end up at the target count."""
        root = tmp_path / "frame_mismatch"
        _write_info(root, total_episodes=1, total_frames=3)
        _write_recovery(root, 3)
        _write_images(root, 3)
        _write_parquet(root, [0, 0, 0])
        _write_video(root)

        outcome = self._repair(
            DiagnosisResult(
                dataset_dir=root,
                damage_type=DamageType.FRAME_MISMATCH,
                repairable=True,
                details={"truncate_target_frames": 2, "n_parquet_rows": 3},
            )
        )

        recovery_text = (root / "recovery_frames.jsonl").read_text(encoding="utf-8")
        recovery_lines = recovery_text.splitlines()
        image_files = sorted((root / "images").rglob("frame-*.png"))
        table = pq.read_table(root / "data" / "chunk-000" / "file-000.parquet")
        info = self._read_json(root / "meta" / "info.json")
        assert outcome.outcome == "repaired"
        assert len(recovery_lines) == 2
        assert len(image_files) == 2
        assert table.num_rows == 2
        assert info["total_frames"] == 2
class TestDatasetDiagnosis:
    """Damage classification checked against small on-disk fixtures."""

    def test_tmp_videos_stuck_wins_before_crash_no_save(self, tmp_path: Path) -> None:
        """A leftover temporary MP4 is reported as TMP_VIDEOS_STUCK and is repairable."""
        root = tmp_path / "tmp_stuck"
        _write_info(root, total_episodes=0, total_frames=0)
        _write_recovery(root, 2)
        stuck_dir = root / "tmpabc"
        stuck_dir.mkdir(parents=True)
        stuck_video = stuck_dir / "observation.images.front_000.mp4"
        stuck_video.write_bytes(b"mp4")

        found = diagnose_dataset(root)

        assert found.damage_type is DamageType.TMP_VIDEOS_STUCK
        assert found.repairable is True

    def test_missing_cp_detected_from_log_when_data_present(self, tmp_path: Path) -> None:
        """A sibling log with one CP entry and complete data yields MISSING_CP."""
        root = tmp_path / "missing_cp"
        _write_info(root)
        _write_parquet(root, 3)
        _write_images(root, 3)
        _write_video(root)
        log_file = root.parent / "missing_cp.log"
        log_file.write_text(
            "[CP] END at episode 0, frame 2 (segment: 1-2, 1 frames, outcome=success)\n",
            encoding="utf-8",
        )

        found = diagnose_dataset(root)

        assert found.damage_type is DamageType.MISSING_CP
        assert found.details["n_log_cp"] == 1

    def test_frame_mismatch_uses_smallest_available_count(self, tmp_path: Path) -> None:
        """With 4 recovery rows, 3 images and 2 parquet rows the truncate target is 2."""
        root = tmp_path / "frame_mismatch"
        _write_info(root, total_frames=4)
        _write_recovery(root, 4)
        _write_images(root, 3)
        _write_parquet(root, 2)
        _write_video(root)

        found = diagnose_dataset(root)

        assert found.damage_type is DamageType.FRAME_MISMATCH
        assert found.details["truncate_target_frames"] == 2