diff --git a/.gitignore b/.gitignore index 5414c71..f5c968b 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,6 @@ Thumbs.db .env venv/ .venv/ + +# Git worktrees +.worktrees/ diff --git a/configs/homography/.gitkeep b/configs/homography/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/configs/models.py b/configs/models.py new file mode 100644 index 0000000..4ce278d --- /dev/null +++ b/configs/models.py @@ -0,0 +1,22 @@ +"""Centralized model path configuration.""" + +from configs.paths import PROJECT_ROOT + +RUNS_DIR = PROJECT_ROOT / "runs" / "detect" + +# Current model (v2) +V2_MODEL = RUNS_DIR / "frisbee_det_s_v2" / "weights" / "best.pt" + +# Legacy model (v1) +V1_MODEL = RUNS_DIR / "frisbee_det_s" / "weights" / "best.pt" + +# v3 model (cleaned data, box=5) +V3_MODEL = RUNS_DIR / "frisbee_det_s_v3" / "weights" / "best.pt" + +DEFAULT_MODEL = V3_MODEL +DEFAULT_MODEL_SIZE = "s" +DEFAULT_IMGSZ = 1280 +DEFAULT_CONF = 0.35 +DEFAULT_EPOCHS = 100 +DEFAULT_BATCH = 2 +SEED = 42 diff --git a/docs/superpowers/plans/2026-05-13-v5-precision-improvement.md b/docs/superpowers/plans/2026-05-13-v5-precision-improvement.md new file mode 100644 index 0000000..11648a8 --- /dev/null +++ b/docs/superpowers/plans/2026-05-13-v5-precision-improvement.md @@ -0,0 +1,296 @@ +# v5 Precision Improvement — 实现计划 + +> **面向 AI 代理的工作者:** 必需子技能:使用 superpowers:subagent-driven-development(推荐)或 superpowers:executing-plans 逐任务实现此计划。步骤使用复选框(`- [ ]`)语法来跟踪进度。 + +**目标:** 在 `models/train.py` 中添加 `--cls` 参数支持,用 `cls=1.3` 重新训练 v5 模型,将误检率从 ~60% 降至 <15%。 + +**架构:** 单文件代码变更 + 4 个运维步骤。先在 `train.py` 中添加 `--cls` argparse + 函数参数,然后在 tmux 中训练 v5,最后在验证集和两个测试视频上评估。 + +**技术栈:** Python 3.10, ultralytics (YOLOv8), PyTorch, SAHI, OpenCV + +--- + +## 文件变更 + +| 文件 | 变更 | 职责 | +|------|------|------| +| `models/train.py` | 修改: 19-76 | 添加 `--cls` argparse 参数,传入 `train_frisbee_detector()` | + +无新增文件。无测试文件(代码变更极简,属于训练脚本的参数传递)。 + +--- + +### 任务 1:添加 `--cls` 参数到 train.py + +**文件:** +- 修改:`models/train.py:19-76`(`train_frisbee_detector` 函数签名 + argparse) + +**分析:** 当前 `train_frisbee_detector()` 函数签名已有 `box` 参数但无 `cls`。argparse 也无 `--cls` 参数。需要同时添加两者并用 kwargs 方式传入 `model.train()`。 + +当前 `model.train()` 调用没有显式传 `cls`,因此 YOLO 使用默认值 0.5。v4 的 args.yaml 证实了这一点。 + +添加 `cls` 参数到函数签名,默认同 Ultralytics 默认值 0.5,以便 `--validate-only` 等非训练路径不受影响。 + +- [ ] **步骤 1:读取当前 train.py** + +```bash +cat models/train.py +``` + +确认当前代码布局。 + +- [ ] **步骤 2:修改函数签名** + +在 `train_frisbee_detector()` 的 `box: float = 7.5` 参数后添加 `cls: float = 0.5`: + +``` +box: float = 7.5, +cls: float = 0.5, +``` + +- [ ] **步骤 3:在 argparse 中添加 `--cls`** + +在 `--box` 参数块后添加: + +```python +parser.add_argument("--cls", type=float, default=0.5, help="Classification loss weight") +``` + +- [ ] **步骤 4:将 cls 传入函数调用** + +在 `train_frisbee_detector()` 调用位置添加 `cls=args.cls` + +- [ ] **步骤 5:将 cls 传入 model.train()** + +`cls` 参数通过 kwargs 自动传入 `model.train()`,确保已有的调用链将其传入: + +```python +results = model.train( + ... + box=box, + cls=cls, + ... +) +``` + +- [ ] **步骤 6:测试 `--help` 确认参数暴露** + +```bash +python3 models/train.py --help +``` + +确认输出包含 `--cls`。 + +- [ ] **步骤 7:Commit** + +```bash +git add models/train.py +git commit -m "feat: add --cls argument to train.py for classification loss weight" +``` + +--- + +### 任务 2:训练 v5 模型 + +**文件:** +- 无代码变更。纯运维操作。 + +**重要约束:** YOLO 训练 >10min,必须在 tmux 中运行。Bash 工具有 10 分钟超时,训练需 1.5-2.5h。 + +**注意双嵌套 bug:** `project="runs/detect"` 导致输出到 `runs/detect/runs/detect/frisbee_det_s_v5/`。训练结束后必须移出。 + +- [ ] **步骤 1:启动 tmux 训练会话** + +```bash +tmux new-session -d -s train -c /mnt/e/frisbee-detector +tmux send-keys -t train "python3 models/train.py \ + --data configs/frisbee_merged.yaml \ + --model-size s \ + --name frisbee_det_s_v5 \ + --epochs 100 --imgsz 1280 --batch 2 \ + --patience 20 --box 5 --cls 1.3 --close-mosaic 10" Enter +``` + +确认已启动: + +```bash +tmux ls +``` + +- [ ] **步骤 2:确认训练启动状态** + +```bash +sleep 30 && tail -5 runs/detect/runs/detect/frisbee_det_s_v5/results.csv 2>/dev/null || echo "仍在初始化..." +``` + +- [ ] **步骤 3:修复双嵌套路径** + +训练结束后(从 results.csv 确认 epoch 数不再增加或 tmux 会话退出): + +```bash +tmux capture-pane -t train -p | tail -20 +``` + +修复路径: + +```bash +mv runs/detect/runs/detect/frisbee_det_s_v5 runs/detect/frisbee_det_s_v5 +``` + +- [ ] **步骤 4:确认模型文件存在** + +```bash +ls -lh runs/detect/frisbee_det_s_v5/weights/best.pt +``` + +预期:~22MB 文件存在。 + +--- + +### 任务 3:验证集评估 + +**文件:** +- 无代码变更。纯运维操作。 + +- [ ] **步骤 1:运行验证** + +```bash +python3 models/train.py --validate-only \ + --model-path runs/detect/frisbee_det_s_v5/weights/best.pt \ + --data configs/frisbee_merged.yaml +``` + +预期输出 mAP50, mAP50-95, Precision, Recall。 + +- [ ] **步骤 2:记录并与 v4 对比** + +| 指标 | v4 | v5 | 变化 | +|------|:--:|:--:|:----:| +| mAP50 | 0.815 | ? | ? | +| Precision | 0.847 | ? | ? | +| Recall | 0.706 | ? | ? | + +--- + +### 任务 4:测试视频评估 + +**文件:** +- 无代码变更。纯运维操作。 + +- [ ] **步骤 1:评估 55-56min 测试片段** + +```bash +python3 inference/predict_video.py \ + --model runs/detect/frisbee_det_s_v5/weights/best.pt \ + --video movie/25866279684-1-192_55-56min.mp4 \ + --conf 0.20 +``` + +关注:帧检测率(目标 60-70%)、平均检测/帧(目标 1.0-1.3)、置信度分布。 + +- [ ] **步骤 2:评估 20-23min 测试片段** + +```bash +python3 inference/predict_video.py \ + --model runs/detect/frisbee_det_s_v5/weights/best.pt \ + --video movie/clip_20-23min.mp4 \ + --conf 0.20 +``` + +同样关注三个指标。 + +- [ ] **步骤 3:FP 目测抽查** + +各抽 50 帧(共 100 帧),目测每帧的检测框准确性。 + +使用以下方法输出中间帧的标注结果: + +```bash +python3 inference/predict_video.py \ + --model runs/detect/frisbee_det_s_v5/weights/best.pt \ + --video movie/25866279684-1-192_55-56min.mp4 \ + --conf 0.20 +# 从 runs/detect/eval/ 目录看保存的标注结果 +``` + +--- + +### 任务 5:OpenImages 下载(如果 v5 未达标,启动此备份计划) + +**文件:** +- 仅当 v5 FP 率 >15% 或 dets/frame >1.3 时才执行 + +- [ ] **步骤 1:收集 Flying disc ID** + +```bash +cat /mnt/e/firsbee/03_datasets/openimages_frisbee/flying_disc_all_ids.txt \ + /mnt/e/firsbee/03_datasets/openimages_frisbee/disc_golf_all_ids.txt \ + /mnt/e/firsbee/03_datasets/openimages_frisbee/frisbee_games_all_ids.txt \ + | sort -u > /tmp/all_frisbee_ids.txt +wc -l /tmp/all_frisbee_ids.txt +``` + +预期:约 865 行(去重后)。 + +- [ ] **步骤 2:尝试 CVDF 镜像下载** + +从 OpenImages v6 中 Flying disc 类的 ID 下载图片。使用 20 线程并发。 + +安装下载工具: + +```bash +pip install --break-system-packages oidv6 +``` + +或使用 wget/curl 按 ID 构造 URL: + +``` +URL_TEMPLATE="https://storage.googleapis.com/openimages/2018_04/train/train_%s_%02d.jpg" +``` + +如果 Google Storage 返回 403,尝试 CVDF 镜像: + +``` +URL_TEMPLATE="https://storage.cvdfoundation.org/openimages/2018_04/train/train_%s_%02d.jpg" +``` + +- [ ] **步骤 3:验证下载量并记录到 wiki** + +```bash +ls /mnt/e/firsbee/03_datasets/openimages_frisbee/images/*.jpg | wc -l +``` + +抽检 20 张确认标注质量。将下载结果记录到 wiki: + +``` +新条目:raw/articles/openimages-download-result.md +更新:concepts/frisbee-recognition-project.md 中 OpenImages 小节 +``` + +- [ ] **步骤 4:添加至合并数据集** + +修改 `tools/merge_datasets.py` 的 SOURCES 列表: + +```python +{"name": "openimages", "splits": ["train", "val"]}, +``` + +重新运行合并: + +```bash +python3 tools/merge_datasets.py +``` + +- [ ] **步骤 5:训练 v6(与 v5 相同参数)** + +```bash +tmux new-session -d -s train -c /mnt/e/frisbee-detector +tmux send-keys -t train "python3 models/train.py \ + --data configs/frisbee_merged.yaml \ + --model-size s \ + --name frisbee_det_s_v6 \ + --epochs 100 --imgsz 1280 --batch 2 \ + --patience 20 --box 5 --cls 1.3 --close-mosaic 10" Enter +``` + +--- diff --git a/docs/superpowers/plans/2026-05-14-siglip-frisbee-classifier.md b/docs/superpowers/plans/2026-05-14-siglip-frisbee-classifier.md new file mode 100644 index 0000000..ccfd8a0 --- /dev/null +++ b/docs/superpowers/plans/2026-05-14-siglip-frisbee-classifier.md @@ -0,0 +1,431 @@ +# SigLIP 零样本飞盘分类 实现计划 + +> **面向 AI 代理的工作者:** 必需子技能:使用 superpowers:subagent-driven-development(推荐)或 superpowers:executing-plans 逐任务实现此计划。步骤使用复选框(`- [ ]`)语法来跟踪进度。 + +**目标:** 用 SigLIP(Google)零样本分类器过滤 YOLO 检测裁剪图,区分飞盘和白帽子/人头/反光物体等假阳性,产出一个可重用的命令行工具和验证脚本。 + +**架构:** `tools/classify_frisbee.py`(分类器)+ `tools/validate_classifier.py`(验证对比脚本)。分类器加载 `google/siglip-so400m-patch14-384`,用 "a photo of a frisbee" prompt 做零样本分类,支持单图推理和目录批量处理,输出 CSV。验证脚本对比 SigLIP 结果与人工标注,输出混淆矩阵和准确率。 + +**技术栈:** PyTorch 2.11 + Transformers (HuggingFace) + SigLIP (google/siglip-so400m-patch14-384, ~3.5GB, 首次运行自动下载) + +--- + +## 文件结构 + +| 文件 | 职责 | +|------|------| +| `tools/classify_frisbee.py` | SigLIP 零样本分类器:模型加载、单图分类、目录批量处理、argparse 入口 | +| `tools/validate_classifier.py` | 验证脚本:对比分类器结果与人工标注 CSV,输出准确率/FP/FN 统计 | +| `tests/test_classify_frisbee.py` | 测试:模型前向传播形状、单图分类输出格式、目录批量 CSV 正确性 | + +--- + +### 任务 1:注册 pytest 标记 + 创建测试文件 + +**文件:** +- 修改:`pytest.ini` +- 创建:`tests/test_classify_frisbee.py` + +- [ ] **步骤 1:注册 `slow` 标记** + +在 `pytest.ini` 的 `markers` 下追加 `slow`: + +```ini +[pytest] +markers = + integration: marks tests that require real data files + slow: marks tests that require GPU and model download +``` + +- [ ] **步骤 2:编写失败的测试** + +创建 `tests/test_classify_frisbee.py`: + +```python +"""Tests for SigLIP zero-shot frisbee classifier.""" +import os +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from pathlib import Path +import pytest +import csv + + +PROJECT_ROOT = Path(__file__).resolve().parent.parent + + +@pytest.mark.slow +def test_model_forward_pass_output_shape(): + """SigLIP 模型前向传播输出形状为 (1, 1).""" + from tools.classify_frisbee import load_model_and_processor + from PIL import Image + import torch + + model, processor = load_model_and_processor(device="cpu") + dummy_img = Image.new("RGB", (384, 384), color=(128, 128, 128)) + inputs = processor(text=["a photo of a frisbee"], images=dummy_img, + padding="max_length", return_tensors="pt") + + with torch.no_grad(): + outputs = model(**inputs) + + assert outputs.logits_per_image.shape == (1, 1) + + +def test_classify_image_returns_tuple(): + """单图分类返回 (is_frisbee: bool, probability: float).""" + from tools.classify_frisbee import load_model_and_processor, classify_image + from PIL import Image + import tempfile + + model, processor = load_model_and_processor(device="cpu") + + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f: + dummy = Image.new("RGB", (100, 100), color=(0, 0, 0)) + dummy.save(f.name, format="JPEG") + is_frisbee, prob = classify_image(f.name, model, processor, device="cpu") + os.unlink(f.name) + + assert isinstance(is_frisbee, bool) + assert isinstance(prob, float) + assert 0.0 <= prob <= 1.0 + + +def test_classify_directory_writes_csv(tmp_path): + """批量分类目录输出 CSV,包含所有文件.""" + from tools.classify_frisbee import load_model_and_processor, classify_directory + from PIL import Image + + model, processor = load_model_and_processor(device="cpu") + + for i in range(3): + img = Image.new("RGB", (100, 100), color=(i * 50, i * 50, i * 50)) + img.save(str(tmp_path / f"crop_{i}_c0.50_vid.jpg"), format="JPEG") + + output_csv = str(tmp_path / "results.csv") + classify_directory(model, processor, str(tmp_path), output_csv, device="cpu") + + with open(output_csv) as f: + reader = csv.DictReader(f) + rows = list(reader) + assert len(rows) == 3 + for row in rows: + assert "filename" in row + assert "label" in row + assert "confidence" in row + assert row["label"] in ("frisbee", "not_frisbee") + conf = float(row["confidence"]) + assert 0.0 <= conf <= 1.0 +``` + +- [ ] **步骤 3:运行测试验证失败** + +```bash +python3 -m pytest tests/test_classify_frisbee.py::test_classify_image_returns_tuple -v +``` +预期:`ModuleNotFoundError: No module named 'tools.classify_frisbee'` + +- [ ] **步骤 4:编写最少实现** + +创建 `tools/classify_frisbee.py`: + +```python +"""SigLIP zero-shot frisbee classifier for YOLO detection crops.""" +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import argparse +import csv +import torch +from PIL import Image + + +def load_model_and_processor(device="cuda"): + """Load SigLIP model and processor. First call downloads ~3.5GB.""" + from transformers import AutoModel, AutoProcessor + model_name = "google/siglip-so400m-patch14-384" + model = AutoModel.from_pretrained(model_name).to(device) + processor = AutoProcessor.from_pretrained(model_name) + model.eval() + return model, processor + + +def classify_image(image_path, model, processor, device="cuda", threshold=0.5): + """Classify a single crop. Returns (is_frisbee, probability).""" + prompt = "a photo of a frisbee or flying disc" + image = Image.open(str(image_path)).convert("RGB") + + inputs = processor(text=[prompt], images=image, + padding="max_length", return_tensors="pt").to(device) + + with torch.no_grad(): + outputs = model(**inputs) + logits = outputs.logits_per_image + prob = torch.sigmoid(logits)[0][0].item() + + return prob >= threshold, prob + + +def classify_directory(model, processor, img_dir, output_csv, device="cuda", threshold=0.5): + """Classify all images in a directory, write results to CSV.""" + img_dir = Path(img_dir) + results = [] + + image_files = sorted( + p for p in img_dir.iterdir() + if p.suffix.lower() in (".jpg", ".jpeg", ".png") and p.is_file() + ) + + for img_path in image_files: + is_frisbee, prob = classify_image(str(img_path), model, processor, device=device, threshold=threshold) + label = "frisbee" if is_frisbee else "not_frisbee" + results.append((img_path.name, label, prob)) + + with open(output_csv, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["filename", "label", "confidence"]) + for filename, label, prob in results: + writer.writerow([filename, label, f"{prob:.4f}"]) + + return len(results) + + +def main(): + parser = argparse.ArgumentParser(description="SigLIP zero-shot frisbee classifier") + parser.add_argument("--crop-dir", required=True, help="Directory with crop images") + parser.add_argument("--output", default=None, help="Output CSV path") + parser.add_argument("--threshold", type=float, default=0.5, help="Classification threshold (default: 0.5)") + parser.add_argument("--device", default="cuda") + args = parser.parse_args() + + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + print("Loading SigLIP model (first time downloads ~3.5GB)...") + model, processor = load_model_and_processor(device=device) + + crop_dir = Path(args.crop_dir) + output = Path(args.output) if args.output else crop_dir / "siglip_results.csv" + + n = classify_directory(model, processor, str(crop_dir), str(output), device=device, threshold=args.threshold) + print(f"Classified {n} images") + print(f"Results saved to {output}") + + if output.exists(): + frisbee = 0 + not_frisbee = 0 + with open(output) as f: + for row in csv.DictReader(f): + if row["label"] == "frisbee": + frisbee += 1 + else: + not_frisbee += 1 + total = frisbee + not_frisbee + if total > 0: + print(f"frisbee={frisbee} not_frisbee={not_frisbee} " + f"detection_rate={frisbee/total*100:.1f}%") + + +if __name__ == "__main__": + main() +``` + +- [ ] **步骤 5:运行测试验证通过** + +```bash +python3 -m pytest tests/test_classify_frisbee.py -v -m "not slow" +``` +预期:2 PASSED, 1 deselected + +- [ ] **步骤 6:运行慢速测试(首次需下载 ~3.5GB)** + +```bash +python3 -m pytest tests/test_classify_frisbee.py::test_model_forward_pass_output_shape -v +``` +预期:1 PASSED + +- [ ] **步骤 7:Commit** + +```bash +git add pytest.ini tests/test_classify_frisbee.py tools/classify_frisbee.py +git commit -m "feat: add SigLIP zero-shot frisbee classifier" +``` + +--- + +### 任务 2:验证对比脚本 + +**文件:** +- 创建:`tools/validate_classifier.py` + +- [ ] **步骤 1:创建验证脚本** + +创建 `tools/validate_classifier.py`: + +```python +"""Compare classifier results against human labels, output confusion matrix.""" +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import argparse +import csv + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--human-csv", required=True, help="Human labels CSV (review_results.csv)") + parser.add_argument("--classifier-csv", required=True, help="Classifier output CSV") + parser.add_argument("--human-label-col", default="result", help="Column name for human label") + parser.add_argument("--human-positive", default="TP", help="Positive label value in human CSV") + parser.add_argument("--classifier-label-col", default="label", help="Column name for classifier label") + parser.add_argument("--classifier-positive", default="frisbee", help="Positive label value in classifier CSV") + args = parser.parse_args() + + human = {} + with open(args.human_csv) as f: + for row in csv.DictReader(f): + human[row["filename"]] = row[args.human_label_col].strip().upper() == args.human_positive.upper() + + classifier = {} + with open(args.classifier_csv) as f: + for row in csv.DictReader(f): + classifier[row["filename"]] = row[args.classifier_label_col] == args.classifier_positive + + tp = fp = fn = tn = 0 + for fname, is_positive in human.items(): + if fname not in classifier: + continue + pred = classifier[fname] + if is_positive and pred: + tp += 1 + elif not is_positive and pred: + fp += 1 + elif is_positive and not pred: + fn += 1 + else: + tn += 1 + + total = tp + fp + fn + tn + accuracy = (tp + tn) / max(total, 1) * 100 + precision = tp / max(tp + fp, 1) * 100 + recall = tp / max(tp + fn, 1) * 100 + f1 = 2 * precision * recall / max(precision + recall, 1) + + print(f"=== Classification Report ===") + print(f"Total: {total}") + print(f"TP (correct frisbee): {tp}") + print(f"FP (false frisbee): {fp}") + print(f"FN (missed frisbee): {fn}") + print(f"TN (correct non-frisbee): {tn}") + print(f"Accuracy: {accuracy:.1f}%") + print(f"Precision: {precision:.1f}%") + print(f"Recall: {recall:.1f}%") + print(f"F1: {f1:.1f}%") + + +if __name__ == "__main__": + main() +``` + +- [ ] **步骤 2:Commit** + +```bash +git add tools/validate_classifier.py +git commit -m "feat: add classifier validation script" +``` + +--- + +### 任务 3:在 200 张标注裁剪图上验证 + +- [ ] **步骤 1:运行 SigLIP 分类器** + +```bash +python3 tools/classify_frisbee.py --crop-dir data/perbox_crops --device cuda +``` +预期:输出 `data/perbox_crops/siglip_results.csv`,显示 TP/FP 统计 + +- [ ] **步骤 2:对比 SigLIP 结果与人工标注** + +```bash +python3 tools/validate_classifier.py \ + --human-csv data/perbox_crops/review_results.csv \ + --classifier-csv data/perbox_crops/siglip_results.csv +``` +预期:输出混淆矩阵和 Precision/Recall/F1 + +- [ ] **步骤 3:决策门** + +根据输出决定: + +| SigLIP F1 | 结论 | +|:---------:|------| +| > 85% | ✅ 零样本方案有效,进入任务 4 | +| 70-85% | ⚠️ 尝试调 `--threshold`(0.3, 0.4, 0.6, 0.7)重新跑,选最优 F1 | +| < 70% | ❌ SigLIP 效果不足,报告失败 | + +如果需调阈值,重新运行: + +```bash +python3 tools/classify_frisbee.py --crop-dir data/perbox_crops --threshold 0.3 --device cuda +python3 tools/validate_classifier.py --human-csv data/perbox_crops/review_results.csv --classifier-csv data/perbox_crops/siglip_results.csv +``` + +- [ ] **步骤 4:记录结果** + +```bash +git add -A +git commit -m "results: SigLIP validation on 200 labeled crops" +``` + +--- + +### 任务 4:过滤 v7_quick 全部检测裁剪图(仅任务 3 通过后执行) + +- [ ] **步骤 1:分类 55-56min 裁剪图** + +```bash +python3 tools/classify_frisbee.py --crop-dir data/fp_spotcheck_v7_55min --threshold --device cuda +``` + +- [ ] **步骤 2:分类 20-23min 裁剪图** + +```bash +python3 tools/classify_frisbee.py --crop-dir data/fp_spotcheck_v7_20min --threshold --device cuda +``` + +- [ ] **步骤 3:汇总两个视频的过滤结果** + +查看两个 `siglip_results.csv` 的 frisbee/not_frisbee 统计,估算 v7_quick 的真实 FP 率。 + +- [ ] **步骤 4:记录结果并决策** + +```bash +git add -A +git commit -m "results: SigLIP filtering on v7_quick 4228 crops" +``` + +--- + +## 自检 + +**1. 规格覆盖度:** +- [x] SigLIP 模型加载 + 前向传播测试 → 任务 1 (`test_model_forward_pass_output_shape`) +- [x] 单图分类返回 `(bool, float)` → 任务 1 (`test_classify_image_returns_tuple`) +- [x] 目录批量分类输出 CSV → 任务 1 (`test_classify_directory_writes_csv`) +- [x] argparse 入口(`--crop-dir`, `--output`, `--threshold`, `--device`)→ 任务 1 +- [x] 阈值可调 → 任务 1 `--threshold` 参数 +- [x] 验证对比脚本(混淆矩阵 + Precision/Recall/F1)→ 任务 2 +- [x] 在 200 张已知数据上验证准确率 → 任务 3 +- [x] 阈值调优流程 → 任务 3 步骤 3 +- [x] 过滤 v7_quick 全量裁剪图 → 任务 4 + +**2. 占位符扫描:** 无 TODO、无 "待定"。任务 4 中的 `` 由任务 3 确定后填入,是唯一的动态值,已在步骤中注明。 + +**3. 类型一致性:** +- `load_model_and_processor` → 返回 `(model, processor)` → 所有函数签名接收 `(model, processor, ...)` +- `classify_image` → 返回 `(bool, float)` → `classify_directory` 解构为 `is_frisbee, prob` +- CSV 列名:`filename`, `label`, `confidence` → `validate_classifier.py` 使用 `--classifier-label-col` 和 `--classifier-positive` 参数匹配 +- 标签值:`"frisbee"` / `"not_frisbee"` → 测试和实现中一致 diff --git a/docs/superpowers/plans/2026-05-14-v7-classifier.md b/docs/superpowers/plans/2026-05-14-v7-classifier.md new file mode 100644 index 0000000..8a2a314 --- /dev/null +++ b/docs/superpowers/plans/2026-05-14-v7-classifier.md @@ -0,0 +1,878 @@ +# v7 二分类器 + 硬负样本扩充 实现计划 + +> **面向 AI 代理的工作者:** 必需子技能:使用 superpowers:subagent-driven-development(推荐)或 superpowers:executing-plans 逐任务实现此计划。步骤使用复选框(`- [ ]`)语法来跟踪进度。 + +**目标:** 用200张已标注裁剪图(22 TP + 178 FP)训练ResNet18二分类器,自动分类v7_quick的4228个检测,获取真实FP率并收集新硬负样本用于v8训练。 + +**架构:** `tools/train_frisbee_classifier.py`(训练脚本)→ 产出模型 → `tools/classify_crops.py`(推理脚本)→ 产出各裁剪图的 TP/FP 标签。两者共享 `classifier_utils.py` 中的通用函数(数据加载、模型构建、图像变换)。 + +**技术栈:** PyTorch 2.11 + torchvision (ResNet18)、OpenCV(图像读取)、pytest + +--- + +## 文件结构 + +| 文件 | 职责 | +|------|------| +| `tests/test_classifier.py` | 测试:数据加载、模型构造、训练冒烟测试、单图分类、批量分类 | +| `tools/classifier_utils.py` | 共享模块:`load_labeled_dataset()`、`create_model()`、`get_transforms()` | +| `tools/train_frisbee_classifier.py` | 训练入口:加载数据 → 训练 → 验证 → 保存模型 | +| `tools/classify_crops.py` | 推理入口:加载模型 → 批量分类裁剪图目录 → 输出CSV | + +--- + +### 任务 1:创建测试文件 + 数据加载测试 + +**文件:** +- 创建:`tests/test_classifier.py` +- 创建:`tools/classifier_utils.py` + +- [ ] **步骤 1:编写失败的测试(数据加载)** + +```python +"""Tests for frisbee binary classifier.""" +import os +import sys + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from pathlib import Path + +import pytest +import numpy as np + + +PROJECT_ROOT = Path(__file__).resolve().parent.parent +REVIEW_CSV = PROJECT_ROOT / "data" / "perbox_crops" / "review_results.csv" +CROP_DIR = PROJECT_ROOT / "data" / "perbox_crops" + + +def test_load_labeled_dataset(): + """加载已标注数据集,返回 (路径, 标签) 列表,0=FP, 1=TP.""" + from tools.classifier_utils import load_labeled_dataset + + samples = load_labeled_dataset(REVIEW_CSV, CROP_DIR) + + assert len(samples) == 200 + paths, labels = zip(*samples) + assert all(p.is_file() for p in paths) + assert set(labels) == {0, 1} + + tp_count = sum(labels) + fp_count = len(labels) - tp_count + assert tp_count == 22 + assert fp_count == 178 +``` + +```python +def test_train_val_split(): + """分层划分训练/验证集,保留类别比例.""" + from tools.classifier_utils import split_train_val + + import tempfile, csv + + # Create a minimal temporary dataset + with tempfile.TemporaryDirectory() as tmpdir: + tmp = Path(tmpdir) + # Create 5 fake images and a CSV + for i in range(5): + import cv2 + img = np.zeros((100, 100, 3), dtype=np.uint8) + cv2.imwrite(str(tmp / f"crop_{i}.jpg"), img) + + csv_path = tmp / "review.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["filename", "result", "frame", "conf"]) + writer.writerow(["crop_0.jpg", "TP", "1", "0.5"]) + writer.writerow(["crop_1.jpg", "FP", "2", "0.4"]) + writer.writerow(["crop_2.jpg", "TP", "3", "0.6"]) + writer.writerow(["crop_3.jpg", "FP", "4", "0.3"]) + writer.writerow(["crop_4.jpg", "TP", "5", "0.7"]) + + samples = [(tmp / row[0], 1 if row[1] == "TP" else 0) for row in [ + ("crop_0.jpg", "TP"), ("crop_1.jpg", "FP"), ("crop_2.jpg", "TP"), + ("crop_3.jpg", "FP"), ("crop_4.jpg", "TP") + ]] + + train, val = split_train_val(samples, val_ratio=0.4, seed=42) + + assert len(train) + len(val) == 5 + train_tp = sum(1 for _, l in train if l == 1) + train_fp = sum(1 for _, l in train if l == 0) + val_tp = sum(1 for _, l in val if l == 1) + val_fp = sum(1 for _, l in val if l == 0) + + assert train_tp + val_tp == 3 + assert train_fp + val_fp == 2 + assert train_tp >= 1 and val_tp >= 1 + assert train_fp >= 1 and val_fp >= 1 +``` + +- [ ] **步骤 2:运行测试验证失败** + +```bash +python3 -m pytest tests/test_classifier.py::test_load_labeled_dataset -v +``` +预期:`ModuleNotFoundError: No module named 'tools.classifier_utils'` + +- [ ] **步骤 3:编写最小实现** + +创建 `tools/classifier_utils.py`: + +```python +"""Shared utilities for frisbee classifier.""" +import csv +from pathlib import Path + + +def load_labeled_dataset(csv_path, img_dir): + """Load labeled dataset from CSV and image directory. + + Returns: + list of (Path, int): (image_path, label) where label is 0=FP, 1=TP. + """ + samples = [] + img_dir = Path(img_dir) + with open(csv_path) as f: + for row in csv.DictReader(f): + label = 1 if row["result"].strip().upper() == "TP" else 0 + img_path = img_dir / row["filename"] + if img_path.is_file(): + samples.append((img_path, label)) + return samples + + +def split_train_val(samples, val_ratio=0.2, seed=42): + """Stratified split into train and validation sets.""" + import random + random.seed(seed) + + tp_samples = [(p, l) for p, l in samples if l == 1] + fp_samples = [(p, l) for p, l in samples if l == 0] + + random.shuffle(tp_samples) + random.shuffle(fp_samples) + + tp_split = int(len(tp_samples) * (1 - val_ratio)) + fp_split = int(len(fp_samples) * (1 - val_ratio)) + + train = tp_samples[:tp_split] + fp_samples[:fp_split] + val = tp_samples[tp_split:] + fp_samples[fp_split:] + + random.shuffle(train) + random.shuffle(val) + return train, val +``` + +- [ ] **步骤 4:运行测试验证通过** + +```bash +python3 -m pytest tests/test_classifier.py::test_load_labeled_dataset tests/test_classifier.py::test_train_val_split -v +``` +预期:2 PASSED + +- [ ] **步骤 5:Commit** + +```bash +git add tests/test_classifier.py tools/classifier_utils.py +git commit -m "feat: add labeled dataset loader for binary classifier" +``` + +--- + +### 任务 2:模型构造 + 前向传播测试 + +**文件:** +- 修改:`tests/test_classifier.py`(追加测试) +- 修改:`tools/classifier_utils.py`(追加函数) + +- [ ] **步骤 1:编写失败的测试** + +在 `tests/test_classifier.py` 末尾追加: + +```python +def test_create_model_output_shape(): + """模型输出形状正确:batch_size=4, num_classes=2.""" + from tools.classifier_utils import create_model + import torch + + model = create_model(num_classes=2, pretrained=True) + model.eval() + + dummy_input = torch.randn(4, 3, 224, 224) + with torch.no_grad(): + output = model(dummy_input) + + assert output.shape == (4, 2) +``` + +```python +def test_get_transforms_train_inference(): + """训练transforms有数据增强,推理transforms没有.""" + from tools.classifier_utils import get_transforms + + train_transform = get_transforms(is_train=True) + eval_transform = get_transforms(is_train=False) + + import torchvision.transforms as T + # Verify train has RandomHorizontalFlip + has_random = any(isinstance(t, T.RandomHorizontalFlip) for t in train_transform.transforms) + assert has_random, "train transform should have augmentation" + + # Verify eval does not + has_random_eval = any(isinstance(t, T.RandomHorizontalFlip) for t in eval_transform.transforms) + assert not has_random_eval, "eval transform should not have augmentation" +``` + +- [ ] **步骤 2:运行测试验证失败** + +```bash +python3 -m pytest tests/test_classifier.py::test_create_model_output_shape -v +``` +预期:`AttributeError: module 'tools.classifier_utils' has no attribute 'create_model'` + +- [ ] **步骤 3:编写最小实现** + +在 `tools/classifier_utils.py` 末尾追加: + +```python +def create_model(num_classes=2, pretrained=True): + """Create a ResNet18 binary classifier.""" + import torch + import torch.nn as nn + import torchvision.models as models + + model = models.resnet18(weights="IMAGENET1K_V1" if pretrained else None) + in_features = model.fc.in_features + model.fc = nn.Linear(in_features, num_classes) + return model + + +def get_transforms(is_train=True): + """Get image transforms for training or inference.""" + import torchvision.transforms as T + + if is_train: + return T.Compose([ + T.ToPILImage(), + T.Resize((224, 224)), + T.RandomHorizontalFlip(p=0.5), + T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + else: + return T.Compose([ + T.ToPILImage(), + T.Resize((224, 224)), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) +``` + +- [ ] **步骤 4:运行测试验证通过** + +```bash +python3 -m pytest tests/test_classifier.py::test_create_model_output_shape tests/test_classifier.py::test_get_transforms_train_inference -v +``` +预期:2 PASSED + +- [ ] **步骤 5:Commit** + +```bash +git add tests/test_classifier.py tools/classifier_utils.py +git commit -m "feat: add ResNet18 model creation and image transforms" +``` + +--- + +### 任务 3:分类器推理函数测试 + +**文件:** +- 修改:`tests/test_classifier.py`(追加测试) +- 修改:`tools/classifier_utils.py`(追加函数) + +- [ ] **步骤 1:编写失败的测试** + +在 `tests/test_classifier.py` 末尾追加: + +```python +def test_classify_image_returns_label_and_confidence(): + """单图分类返回 (label, confidence) 元组.""" + from tools.classifier_utils import create_model, get_transforms, classify_image + import torch + + model = create_model(num_classes=2, pretrained=True) + model.eval() + transform = get_transforms(is_train=False) + + # Create a dummy image + import cv2 + import tempfile + with tempfile.NamedTemporaryFile(suffix=".jpg") as f: + dummy = np.zeros((200, 200, 3), dtype=np.uint8) + cv2.imwrite(f.name, dummy) + label, conf = classify_image(model, f.name, transform, device="cpu") + + assert label in (0, 1) + assert 0.0 <= conf <= 1.0 + assert isinstance(conf, float) +``` + +```python +def test_classify_directory_writes_csv(tmp_path): + """批量分类目录输出CSV,包含所有裁剪图.""" + from tools.classifier_utils import create_model, get_transforms, classify_directory + import cv2 + + model = create_model(num_classes=2, pretrained=True) + model.eval() + transform = get_transforms(is_train=False) + + # Create 3 fake crops + for i in range(3): + img = np.zeros((100, 100, 3), dtype=np.uint8) + cv2.imwrite(str(tmp_path / f"crop_{i}_c0.50_vid.jpg"), img) + + output_csv = str(tmp_path / "results.csv") + classify_directory(model, str(tmp_path), output_csv, transform, device="cpu") + + with open(output_csv) as f: + reader = csv.DictReader(f) + rows = list(reader) + assert len(rows) == 3 + assert "filename" in rows[0] + assert "label" in rows[0] + assert "confidence" in rows[0] + assert rows[0]["label"] in ("0", "1") +``` + +- [ ] **步骤 2:运行测试验证失败** + +```bash +python3 -m pytest tests/test_classifier.py::test_classify_image_returns_label_and_confidence -v +``` +预期:`AttributeError: module 'tools.classifier_utils' has no attribute 'classify_image'` + +- [ ] **步骤 3:编写最小实现** + +在 `tools/classifier_utils.py` 末尾追加: + +```python +def classify_image(model, image_path, transform, device="cuda"): + """Classify a single image. Returns (label, confidence).""" + import torch + import cv2 + + img = cv2.imread(str(image_path)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + tensor = transform(img).unsqueeze(0).to(device) + + with torch.no_grad(): + output = model(tensor) + probs = torch.softmax(output, dim=1) + conf, pred = probs.max(dim=1) + + return pred.item(), conf.item() + + +def classify_directory(model, img_dir, output_csv, transform, device="cuda"): + """Classify all images in a directory, write results to CSV.""" + import csv + from pathlib import Path + + img_dir = Path(img_dir) + results = [] + + image_files = sorted( + p for p in img_dir.iterdir() + if p.suffix.lower() in (".jpg", ".jpeg", ".png") and p.is_file() + ) + + for img_path in image_files: + label, conf = classify_image(model, str(img_path), transform, device=device) + results.append((img_path.name, label, conf)) + + with open(output_csv, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["filename", "label", "confidence"]) + for filename, label, conf in results: + writer.writerow([filename, label, f"{conf:.4f}"]) + + return len(results) +``` + +- [ ] **步骤 4:运行测试验证通过** + +```bash +python3 -m pytest tests/test_classifier.py::test_classify_image_returns_label_and_confidence tests/test_classifier.py::test_classify_directory_writes_csv -v +``` +预期:2 PASSED + +- [ ] **步骤 5:Commit** + +```bash +git add tests/test_classifier.py tools/classifier_utils.py +git commit -m "feat: add single-image and directory classifier inference" +``` + +--- + +### 任务 4:训练脚本(训练 + 验证循环) + +**文件:** +- 创建:`tools/train_frisbee_classifier.py` +- 修改:`tests/test_classifier.py`(追加冒烟测试) + +- [ ] **步骤 1:编写失败的冒烟测试** + +在 `tests/test_classifier.py` 末尾追加: + +```python +def test_training_can_overfit_small_batch(): + """训练能在10张图上过拟合(loss接近0,acc接近1.0).""" + from tools.classifier_utils import create_model, get_transforms + import torch + import tempfile + import cv2 + import csv + + model = create_model(num_classes=2, pretrained=True) + + # Create 10 labeled images in temp dir + with tempfile.TemporaryDirectory() as tmpdir: + tmp = Path(tmpdir) + for i in range(5): + img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + cv2.imwrite(str(tmp / f"tp_{i}.jpg"), img) + for i in range(5): + img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + cv2.imwrite(str(tmp / f"fp_{i}.jpg"), img) + + csv_path = tmp / "review.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["filename", "result", "frame", "conf"]) + for i in range(5): + writer.writerow([f"tp_{i}.jpg", "TP", str(i), "0.5"]) + for i in range(5): + writer.writerow([f"fp_{i}.jpg", "FP", str(i+5), "0.5"]) + + # Use the training function directly + from tools.classifier_utils import load_labeled_dataset, split_train_val + samples = load_labeled_dataset(csv_path, tmp) + train_samples, val_samples = split_train_val(samples, val_ratio=0.2) + + # Train for 50 epochs on this tiny set + from tools.train_frisbee_classifier import train_one_epoch, validate + import torch.optim as optim + import torch.nn as nn + + model = create_model(num_classes=2, pretrained=False) + device = "cpu" + model.to(device) + + transform = get_transforms(is_train=True) + optimizer = optim.Adam(model.parameters(), lr=1e-3) + criterion = nn.CrossEntropyLoss() + + for epoch in range(100): + train_loss, train_acc = train_one_epoch( + model, train_samples, transform, optimizer, criterion, device, batch_size=4 + ) + val_loss, val_acc = validate( + model, val_samples, get_transforms(is_train=False), criterion, device, batch_size=4 + ) + + # Should overfit + assert train_acc > 0.9, f"train acc {train_acc:.2f} should be > 0.9" +``` + +- [ ] **步骤 2:运行测试验证失败** + +```bash +python3 -m pytest tests/test_classifier.py::test_training_can_overfit_small_batch -v +``` +预期:`ModuleNotFoundError: No module named 'tools.train_frisbee_classifier'` + +- [ ] **步骤 3:编写实现** + +创建 `tools/train_frisbee_classifier.py`: + +```python +"""Train a frisbee binary classifier on manually-labeled detection crops.""" +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import argparse +import random +import torch +import torch.nn as nn +import torch.optim as optim +import cv2 +from classifier_utils import ( + load_labeled_dataset, split_train_val, create_model, get_transforms +) + + +def train_one_epoch(model, samples, transform, optimizer, criterion, device, batch_size=16): + """Train for one epoch. Uses oversampling: ensures batch has both classes.""" + model.train() + tp_samples = [(p, l) for p, l in samples if l == 1] + fp_samples = [(p, l) for p, l in samples if l == 0] + + total_loss = 0.0 + correct = 0 + total = 0 + + # Oversample TP class to balance batches + num_batches = max(len(tp_samples), len(fp_samples)) // (batch_size // 2) + if num_batches == 0: + num_batches = 1 + + for _ in range(num_batches): + # Pick half batch from each class + batch_tp = random.choices(tp_samples, k=batch_size // 2) + batch_fp = random.choices(fp_samples, k=batch_size // 2) + batch = batch_tp + batch_fp + random.shuffle(batch) + + images = [] + labels = [] + for img_path, label in batch: + img = cv2.imread(str(img_path)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + tensor = transform(img) + images.append(tensor) + labels.append(label) + + x = torch.stack(images).to(device) + y = torch.tensor(labels, dtype=torch.long).to(device) + + optimizer.zero_grad() + output = model(x) + loss = criterion(output, y) + loss.backward() + optimizer.step() + + total_loss += loss.item() + _, preds = output.max(dim=1) + correct += (preds == y).sum().item() + total += y.size(0) + + return total_loss / num_batches, correct / total + + +def validate(model, samples, transform, criterion, device, batch_size=16): + """Validate on a set of samples.""" + model.eval() + total_loss = 0.0 + correct = 0 + total = 0 + num_batches = 0 + + for i in range(0, len(samples), batch_size): + batch = samples[i:i + batch_size] + images = [] + labels = [] + for img_path, label in batch: + img = cv2.imread(str(img_path)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + tensor = transform(img) + images.append(tensor) + labels.append(label) + + if not images: + continue + + x = torch.stack(images).to(device) + y = torch.tensor(labels, dtype=torch.long).to(device) + + with torch.no_grad(): + output = model(x) + loss = criterion(output, y) + _, preds = output.max(dim=1) + + total_loss += loss.item() + correct += (preds == y).sum().item() + total += y.size(0) + num_batches += 1 + + return total_loss / max(num_batches, 1), correct / max(total, 1) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--csv", default="data/perbox_crops/review_results.csv") + parser.add_argument("--img-dir", default="data/perbox_crops") + parser.add_argument("--epochs", type=int, default=50) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--name", default="frisbee_classifier_v1") + parser.add_argument("--val-ratio", type=float, default=0.2) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", default="cuda") + args = parser.parse_args() + + random.seed(args.seed) + torch.manual_seed(args.seed) + + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + + # Load data + csv_path = Path(args.csv) + img_dir = Path(args.img_dir) + samples = load_labeled_dataset(csv_path, img_dir) + print(f"Loaded {len(samples)} samples ({sum(1 for _,l in samples if l==1)} TP, {sum(1 for _,l in samples if l==0)} FP)") + + train_samples, val_samples = split_train_val(samples, val_ratio=args.val_ratio, seed=args.seed) + print(f"Train: {len(train_samples)} ({sum(1 for _,l in train_samples if l==1)} TP, {sum(1 for _,l in train_samples if l==0)} FP)") + print(f"Val: {len(val_samples)} ({sum(1 for _,l in val_samples if l==1)} TP, {sum(1 for _,l in val_samples if l==0)} FP)") + + # Create model + model = create_model(num_classes=2, pretrained=True) + model.to(device) + + # Class weights for imbalance + tp_count = sum(1 for _, l in train_samples if l == 1) + fp_count = sum(1 for _, l in train_samples if l == 0) + tp_weight = fp_count / max(tp_count, 1) + class_weights = torch.tensor([1.0, tp_weight], device=device) + criterion = nn.CrossEntropyLoss(weight=class_weights) + + # Separate learning rates + optimizer = optim.Adam([ + {"params": model.fc.parameters(), "lr": args.lr * 10}, + {"params": [p for n, p in model.named_parameters() if "fc" not in n], "lr": args.lr}, + ]) + + transform_train = get_transforms(is_train=True) + transform_val = get_transforms(is_train=False) + + best_val_acc = 0.0 + save_path = Path("runs/classify") / args.name + save_path.mkdir(parents=True, exist_ok=True) + + for epoch in range(args.epochs): + train_loss, train_acc = train_one_epoch( + model, train_samples, transform_train, optimizer, criterion, device, args.batch_size + ) + val_loss, val_acc = validate( + model, val_samples, transform_val, criterion, device, args.batch_size + ) + + print(f"Epoch {epoch+1:3d}/{args.epochs} " + f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} " + f"val_loss={val_loss:.4f} val_acc={val_acc:.4f}") + + if val_acc > best_val_acc: + best_val_acc = val_acc + torch.save({ + "model_state_dict": model.state_dict(), + "val_acc": val_acc, + "args": vars(args), + }, save_path / f"{args.name}.pt") + print(f" => saved (val_acc={val_acc:.4f})") + + print(f"\nBest val acc: {best_val_acc:.4f}") + print(f"Model saved to {save_path / args.name}.pt") + + +if __name__ == "__main__": + main() +``` + +- [ ] **步骤 4:运行冒烟测试验证** + +```bash +python3 -m pytest tests/test_classifier.py::test_training_can_overfit_small_batch -v +``` +预期:PASS(可能有一定波动,acc需 > 0.9) + +- [ ] **步骤 5:Run all tests to confirm nothing broken** + +```bash +python3 -m pytest tests/test_classifier.py -v +``` +预期:全部 PASS + +- [ ] **步骤 6:Commit** + +```bash +git add tools/train_frisbee_classifier.py tests/test_classifier.py +git commit -m "feat: add classifier training script with oversampling" +``` + +--- + +### 任务 5:推理脚本(批量分类 v7_quick 裁剪图) + +**文件:** +- 创建:`tools/classify_crops.py` + +- [ ] **步骤 1:创建推理入口脚本** + +`tools/classify_crops.py` 无需单独测试(`classify_directory` 已在任务3中测试)。直接编写: + +```python +"""Classify detection crops as TP/FP using trained binary classifier.""" +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import argparse +import torch +from classifier_utils import create_model, get_transforms, classify_directory + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True, help="Path to trained .pt model file") + parser.add_argument("--crop-dir", required=True, help="Directory with crop images") + parser.add_argument("--output", default=None, help="Output CSV path (default: /classify_results.csv)") + parser.add_argument("--device", default="cuda") + args = parser.parse_args() + + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + + model = create_model(num_classes=2, pretrained=False) + checkpoint = torch.load(args.model, map_location=device, weights_only=True) + model.load_state_dict(checkpoint["model_state_dict"]) + model.to(device) + model.eval() + + val_acc = checkpoint.get("val_acc", "unknown") + print(f"Model loaded (val_acc={val_acc})") + + transform = get_transforms(is_train=False) + crop_dir = Path(args.crop_dir) + output = Path(args.output) if args.output else crop_dir / "classify_results.csv" + + n = classify_directory(model, str(crop_dir), str(output), transform, device=device) + print(f"Classified {n} images") + print(f"Results saved to {output}") + + # Print summary + if output.exists(): + import csv + tp = fp = 0 + with open(output) as f: + for row in csv.DictReader(f): + if row["label"] == "1": + tp += 1 + else: + fp += 1 + total = tp + fp + print(f"TP={tp} FP={fp} FP_rate={fp/total*100:.1f}%") + + +if __name__ == "__main__": + main() +``` + +- [ ] **步骤 2:Commit** + +```bash +git add tools/classify_crops.py +git commit -m "feat: add batch crop classifier script" +``` + +--- + +### 任务 6:运行实际训练 + 分类 + +- [ ] **步骤 1:训练分类器** + +```bash +python3 tools/train_frisbee_classifier.py \ + --csv data/perbox_crops/review_results.csv \ + --img-dir data/perbox_crops \ + --epochs 50 --batch-size 16 --lr 1e-4 \ + --name frisbee_classifier_v1 +``` +预期:训练完成,val_acc > 0.75 + +- [ ] **步骤 2:分类 55-56min 裁剪图** + +```bash +python3 tools/classify_crops.py \ + --model runs/classify/frisbee_classifier_v1/frisbee_classifier_v1.pt \ + --crop-dir data/fp_spotcheck_v7_55min +``` +预期:输出 `data/fp_spotcheck_v7_55min/classify_results.csv`,显示 TP/FP 统计 + +- [ ] **步骤 3:分类 20-23min 裁剪图** + +```bash +python3 tools/classify_crops.py \ + --model runs/classify/frisbee_classifier_v1/frisbee_classifier_v1.pt \ + --crop-dir data/fp_spotcheck_v7_20min +``` +预期:输出 `data/fp_spotcheck_v7_20min/classify_results.csv` + +- [ ] **步骤 4:汇总两个视频的结果** + +```bash +python3 -c " +import csv +from pathlib import Path + +for name in ['55min', '20min']: + csv_path = Path(f'data/fp_spotcheck_v7_{name}/classify_results.csv') + if csv_path.exists(): + tp = fp = 0 + with open(csv_path) as f: + for row in csv.DictReader(f): + if row['label'] == '1': tp += 1 + else: fp += 1 + total = tp + fp + print(f'v7_{name}: TP={tp} FP={fp} FP_rate={fp/total*100:.1f}% (n={total})') +" +``` +预期:显示每个视频的FP率 + +- [ ] **步骤 5:Commit 训练结果(仅代码和元数据,不含模型)** + +```bash +git add runs/classify/frisbee_classifier_v1/ # 如果有的话 +git commit -m "results: classifier v1 training and crop classification" --allow-empty +``` + +--- + +### 任务 7:分析结果 + 决策 + +- [ ] **步骤 1:获取准确的FP率** + +运行任务6的汇总命令,得到每个视频的TP/FP统计。 + +- [ ] **步骤 2:决策门** + +根据FP率决定: +- **FP率 < 20%** → 直接进入Phase 2生产训练(imgsz=1280, 100 epochs) +- **FP率 ≥ 20%** → 将分类器标为FP的裁剪图复制为硬负样本,编写 `tools/expand_hard_negatives.py`,重新训练v8 + +--- + +## 自检 + +**1. 规格覆盖度:** +- [x] 数据加载:`load_labeled_dataset`, `split_train_val` → 任务1 +- [x] 模型:`create_model` (ResNet18) → 任务2 +- [x] 图像变换:`get_transforms` → 任务2 +- [x] 单图分类:`classify_image` → 任务3 +- [x] 批量分类:`classify_directory` → 任务3 +- [x] 训练循环:`train_one_epoch`, `validate`, `main` → 任务4 +- [x] 推理脚本:`classify_crops.py` → 任务5 +- [x] 实际运行 + 决策 → 任务6, 7 + +**2. 占位符扫描:** 无 TODO/待定/补充细节。所有代码步骤都有完整实现。 + +**3. 类型一致性:** +- `samples` 类型:`list of (Path, int)` 贯穿全部函数 +- `transform` 类型:`torchvision.transforms.Compose` 贯穿全部函数 +- `model` 类型:`nn.Module` 贯穿全部函数 +- 标签约定:`0=FP, 1=TP` 一致 diff --git a/docs/superpowers/plans/2026-05-14-v7-data-driven-precision.md b/docs/superpowers/plans/2026-05-14-v7-data-driven-precision.md new file mode 100644 index 0000000..5d18cf5 --- /dev/null +++ b/docs/superpowers/plans/2026-05-14-v7-data-driven-precision.md @@ -0,0 +1,295 @@ +# v7 Data-Driven Precision — 实现计划 + +> **面向 AI 代理的工作者:** 必需子技能:使用 superpowers:subagent-driven-development(推荐)或 superpowers:executing-plans 逐任务实现此计划。步骤使用复选框(`- [ ]`)语法来跟踪进度。 + +**目标:** 通过数据驱动方法(替换背景为bbox级hard negatives + 添加TP帧)将per-box FP率从89%降至<20%,同时保持帧检测率>=50%。 + +**架构:** 1个新脚本(数据预处理)+ 2个运维步骤(训练 + 评估)。先快速验证(imgsz=640, ~10-15min),通过后再精训(imgsz=1280)。 + +**技术栈:** Python 3.10, ultralytics (YOLOv8), OpenCV, PIL + +**设计规格:** `docs/superpowers/specs/2026-05-14-v7-data-driven-precision-design.md` + +--- + +## 文件变更 + +| 文件 | 变更 | 职责 | +|------|------|------| +| `tools/prep_v7_data.py` | **新增** | 数据预处理:删背景、复制FP crops、提取TP帧 | +| `models/train.py` | 无变更 | 已有 `--resume` 参数可从v3 fine-tune | +| `configs/frisbee_merged.yaml` | 自动重新生成 | `prep_v7_data.py` 调用 `write_yaml_config()` | + +无测试文件要求(预处理脚本有 `--dry-run` 模式验证)。 + +--- + +### 任务 1:编写数据预处理脚本 `tools/prep_v7_data.py` + +**文件:** +- 新增:`tools/prep_v7_data.py` + +**分析:** 需要完成三个数据操作:(1) 删除当前训练集中的283个空标签背景图,(2) 复制178个FP crops作为新背景,(3) 从55-56min测试视频中提取19个TP帧并生成YOLO标签。 + +裁剪图文件名格式:`crop_NNNN_fXXXXX_cC.CONF.jpg`,其中XXXXX是帧号。TP裁剪图覆盖19个唯一帧(1,2,4,5,6,7,8,9,59,60,61,62,63,64,65,66,68,94,96)。 + +- [ ] **步骤 1:创建 `tools/prep_v7_data.py` 脚本** + +```python +"""v7 data preparation: replace backgrounds with FP crops + add TP frames. + +Usage: + python3 tools/prep_v7_data.py # execute + python3 tools/prep_v7_data.py --dry-run # preview only +""" +``` + +脚本结构: + +1. **argparse**: `--dry-run`(仅打印将执行的操作),`--video`(默认55-56min视频路径),`--model`(默认v3模型路径,用于TP帧标签生成) + +2. **删除283个背景图**: + - 遍历 `data/datasets/frisbee_merged/labels/train/*.txt` + - 找到文件大小为0的标签文件(背景图) + - 删除对应的图片文件和标签文件 + - 统计删除数量,打印 + +3. **复制178个FP crops**: + - 读取 `data/perbox_crops/review_results.csv` + - 过滤 `result == "FP"` 的行 + - 将每个FP crop复制到 `data/datasets/frisbee_merged/images/train/`,命名为 `hardneg_v7_{original_filename}` + - 创建对应的空标签文件 `hardneg_v7_{stem}.txt` 在 `labels/train/` + +4. **提取19个TP帧 + 生成YOLO标签**: + - 从CSV中读取所有 `result == "TP"` 的行,收集唯一帧号 + - 用OpenCV打开55-56min视频,逐帧读取并保存到 `images/train/tp_v7_{frame_num:05d}.jpg` + - 用v3模型对每个TP帧做推理(`model.predict(frame, conf=0.20, imgsz=1280)`) + - 将预测结果转为YOLO标签格式(`class cx cy w h`),保存到 `labels/train/tp_v7_{frame_num:05d}.txt` + - **重要**:只保存class 0(frisbee)的预测框,conf >= 0.35 + +5. **重新生成YAML配置**: + - 调用 `utils.dataset.write_yaml_config()` 更新 `configs/frisbee_merged.yaml` + +6. **最终统计**: + - 打印:删除的背景数、新增的FP crop数、新增的TP帧数、最终训练集总大小、正负样本比例 + +- [ ] **步骤 2:Dry-run测试** + +```bash +python3 tools/prep_v7_data.py --dry-run +``` + +验证输出: +- 确认找到283个待删除背景 +- 确认178个FP crops将被复制 +- 确认19个TP帧将被提取 +- 确认不会修改任何文件 + +- [ ] **步骤 3:实际执行预处理** + +```bash +python3 tools/prep_v7_data.py +``` + +验证: +- 打印的统计数字合理(~1698张训练图,~11.8%背景) +- `data/datasets/frisbee_merged/images/train/` 包含 `hardneg_v7_*` 和 `tp_v7_*` 文件 +- `data/datasets/frisbee_merged/labels/train/` 包含对应标签文件 +- `tp_v7_*` 标签文件非空(包含YOLO格式bbox) + +```bash +ls data/datasets/frisbee_merged/images/train/ | wc -l +ls data/datasets/frisbee_merged/images/train/ | grep "^hardneg_v7_" | wc -l +ls data/datasets/frisbee_merged/images/train/ | grep "^tp_v7_" | wc -l +ls data/datasets/frisbee_merged/labels/train/ | grep "^tp_v7_" | xargs -I{} sh -c 'test -s data/datasets/frisbee_merged/labels/train/{} && echo "OK" || echo "EMPTY"' +``` + +--- + +### 任务 2:Phase 1 快速训练(imgsz=640, ~10-15min) + +**文件:** +- 无代码变更。纯运维操作。 + +**重要约束:** +- YOLO训练 >10min,**必须**在tmux中运行 +- 双重嵌套bug:`project="runs/detect"` 导致输出到 `runs/detect/runs/detect/` +- 使用 `--resume` 从v3 fine-tune(不是从头训练) + +- [ ] **步骤 1:确认v3模型存在** + +```bash +ls -lh runs/detect/frisbee_det_s_v3/weights/best.pt +``` + +- [ ] **步骤 2:启动tmux训练会话** + +```bash +tmux new-session -d -s train -c /mnt/e/frisbee-detector +tmux send-keys -t train "python3 models/train.py \ + --data configs/frisbee_merged.yaml \ + --resume runs/detect/frisbee_det_s_v3/weights/best.pt \ + --model-size s \ + --name frisbee_det_s_v7_quick \ + --epochs 30 --imgsz 640 --batch 2 \ + --patience 10 --box 5 --close-mosaic 10 \ + --workers 2" Enter +``` + +确认启动: +```bash +tmux ls +``` + +- [ ] **步骤 3:监控训练启动** + +等待30秒后检查: +```bash +sleep 30 && tmux capture-pane -t train -p | tail -20 +``` + +- [ ] **步骤 4:等待训练完成** + +定期检查(每2-3分钟): +```bash +tmux capture-pane -t train -p | tail -5 +``` + +训练完成标志:tmux会话不再活跃,或输出显示 "Results" / "best.pt"。 + +- [ ] **步骤 5:修复双重嵌套路径** + +```bash +if [ -d "runs/detect/runs/detect/frisbee_det_s_v7_quick" ]; then + mv runs/detect/runs/detect/frisbee_det_s_v7_quick runs/detect/frisbee_det_s_v7_quick +fi +``` + +- [ ] **步骤 6:确认模型文件** + +```bash +ls -lh runs/detect/frisbee_det_s_v7_quick/weights/best.pt +``` + +预期:~22MB。 + +--- + +### 任务 3:Phase 1 评估 + +**文件:** +- 无代码变更。纯运维操作。 + +- [ ] **步骤 1:验证集评估** + +```bash +python3 models/train.py --validate-only \ + --model-path runs/detect/frisbee_det_s_v7_quick/weights/best.pt \ + --data configs/frisbee_merged.yaml --imgsz 640 +``` + +记录:mAP50, Precision, Recall。 + +- [ ] **步骤 2:55-56min测试视频评估** + +```bash +python3 inference/predict_video.py \ + --model runs/detect/frisbee_det_s_v7_quick/weights/best.pt \ + --video movie/25866279684-1-192_55-56min.mp4 --conf 0.35 +``` + +记录:帧检测率、平均dets/帧、置信度分布。 + +- [ ] **步骤 3:20-23min测试视频评估** + +```bash +python3 inference/predict_video.py \ + --model runs/detect/frisbee_det_s_v7_quick/weights/best.pt \ + --video movie/clip_20-23min.mp4 --conf 0.35 +``` + +记录:同上。 + +- [ ] **步骤 4:决策** + +对比阈值: + +| 指标 | v3基线 | v7目标 | v7实际 | 通过? | +|------|:------:|:------:|:------:|:-----:| +| 55-56min帧检测率 | 69.9% | >=50% | ? | ? | +| 20-23min帧检测率 | 61.4% | >=50% | ? | ? | +| 平均dets/帧(55-56) | 1.86 | <1.0 | ? | ? | +| Per-box FP率 | 89% | <20% | ? | ? | + +**通过** → 进入任务4(Phase 2精训) +**未通过** → 停止,分析失败原因,考虑VLM扩增或调整数据策略 + +注意:per-box FP率需要手动抽查。从每个测试视频的检测结果中随机采样50个检测框,人工判断是否为飞盘。 + +--- + +### 任务 4:Phase 2 生产训练(imgsz=1280, ~1-2h) + +**仅当任务3所有指标通过时执行。** + +- [ ] **步骤 1:启动tmux精训会话** + +```bash +tmux new-session -d -s train -c /mnt/e/frisbee-detector +tmux send-keys -t train "python3 models/train.py \ + --data configs/frisbee_merged.yaml \ + --resume runs/detect/frisbee_det_s_v3/weights/best.pt \ + --model-size s \ + --name frisbee_det_s_v7 \ + --epochs 100 --imgsz 1280 --batch 2 \ + --patience 20 --box 5 --close-mosaic 10 \ + --workers 2" Enter +``` + +- [ ] **步骤 2:修复双重嵌套路径** + +```bash +if [ -d "runs/detect/runs/detect/frisbee_det_s_v7" ]; then + mv runs/detect/runs/detect/frisbee_det_s_v7 runs/detect/frisbee_det_s_v7 +fi +``` + +- [ ] **步骤 3:评估(同任务3流程)** + +在两个测试视频上用conf=0.35评估,记录所有指标。 + +- [ ] **步骤 4:更新 `configs/models.py`** + +如果v7优于v3: +- 更新 `DEFAULT_MODEL` 指向v7 +- 更新 `DEFAULT_CONF` 如果最优阈值变化 + +- [ ] **步骤 5:更新 AGENTS.md** + +在模型命名约定表中添加v7行: +``` +frisbee_det_s_v7 → v7 (data-driven FP reduction, 178 hard neg crops + 22 TP frames) +``` + +--- + +## 回滚方案 + +如果v7结果不如v3: +1. 恢复原始训练数据:重新运行 `python3 tools/merge_datasets.py`(从源数据集重建) +2. v3模型文件未受影响(训练不修改原始模型) +3. v7 quick和v7模型可以删除:`rm -rf runs/detect/frisbee_det_s_v7*` + +--- + +## 预估时间 + +| 任务 | 预估时间 | 依赖 | +|------|:--------:|------| +| 任务1:编写+测试预处理脚本 | 10-15min | 无 | +| 任务1:执行预处理 | 2-3min | 脚本就绪 | +| 任务2:快速训练 | 10-15min | 数据就绪 | +| 任务3:评估 | 5-10min | 模型就绪 | +| 任务4:精训(如果通过) | 1-2h | Phase 1通过 | +| **总计(Phase 1)** | **~30-45min** | | +| **总计(含Phase 2)** | **~2.5-3h** | | diff --git a/docs/superpowers/plans/2026-05-18-tracking-pipeline.md b/docs/superpowers/plans/2026-05-18-tracking-pipeline.md new file mode 100644 index 0000000..bdfdb36 --- /dev/null +++ b/docs/superpowers/plans/2026-05-18-tracking-pipeline.md @@ -0,0 +1,425 @@ +# 追踪 + 场地坐标映射 Pipeline 实现计划 + +> **面向 AI 代理的工作者:** 必需子技能:使用 superpowers:subagent-driven-development(推荐)或 superpowers:executing-plans 逐任务实现此计划。步骤使用复选框(`- [ ]`)语法来跟踪进度。 + +**目标:** 创建 `inference/predict_track.py`,在 v3 检测上跑 ByteTrack 追踪,用 Homography 映射场地坐标,输出标注视频 + CSV 轨迹数据。 + +**架构:** 单文件脚本:argparse → YOLO.track() → pixel_to_world() → OpenCV 标注 + CSV 导出。坐标映射复用 `utils/homography.py`。 + +**技术栈:** Python 3, ultralytics (ByteTrack), OpenCV, NumPy + +**规格文档:** `docs/superpowers/specs/2026-05-18-tracking-pipeline-design.md` + +--- + +## 文件结构 + +| 文件 | 操作 | 职责 | +|------|------|------| +| `inference/predict_track.py` | 新建 | 追踪 pipeline 主入口 | +| `tests/test_tracking.py` | 新建 | 辅助函数单元测试 | +| `configs/models.py` | 修改 | 更新 DEFAULT_MODEL 为 v3 | + +--- + +### 任务 1:更新默认模型为 v3 + +**文件:** +- 修改:`configs/models.py` + +- [ ] **步骤 1:修改 `configs/models.py`** + +```python +DEFAULT_MODEL = V3_MODEL +DEFAULT_CONF = 0.35 +``` + +- [ ] **步骤 2:验证** + +运行:`python3 -c "from configs.models import DEFAULT_MODEL, DEFAULT_CONF; print(DEFAULT_MODEL.name, DEFAULT_CONF)"` +预期:输出包含 `frisbee_det_s_v3 0.35` + +- [ ] **步骤 3:Commit** + +```bash +git add configs/models.py +git commit -m "fix: set DEFAULT_MODEL to v3, DEFAULT_CONF to 0.35" +``` + +--- + +### 任务 2:创建辅助函数测试 + +**文件:** +- 创建:`tests/test_tracking.py` + +- [ ] **步骤 1:编写测试** + +```python +"""Tests for tracking pipeline helper functions.""" + +import datetime +import os +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def test_calibration_lookup_exact_match(tmp_path, monkeypatch): + """Simulate find_calibration logic: exact stem match.""" + calib_dir = tmp_path / "homography" + calib_dir.mkdir(parents=True) + (calib_dir / "test_video.json").write_text('{"matrix":[[1,0,0],[0,1,0],[0,0,1]],' + '"video":"test.mp4","image_size":[1280,720],' + '"field_size_m":[100,37],"calibration_frame":0,' + '"points":[],"reprojection_error_px":0.5}') + + stem = "test_video" + candidate = calib_dir / f"{stem}.json" + assert candidate.exists() + + import json + data = json.loads(candidate.read_text()) + assert data["video"] == "test.mp4" + + +def test_calibration_lookup_underscore_fallback(tmp_path): + """Simulate find_calibration: fallback to base before first underscore.""" + calib_dir = tmp_path / "homography" + calib_dir.mkdir(parents=True) + (calib_dir / "base.json").write_text('{"matrix":[[1,0,0],[0,1,0],[0,0,1]],' + '"video":"base.mp4","image_size":[1280,720],' + '"field_size_m":[100,37],"calibration_frame":0,' + '"points":[],"reprojection_error_px":0.5}') + + video_stem = "base_55-56min" + base = video_stem.split("_", 1)[0] if "_" in video_stem else video_stem + assert base == "base" + candidate = calib_dir / f"{base}.json" + assert candidate.exists() + + import json + data = json.loads(candidate.read_text()) + assert data["video"] == "base.mp4" + + +def test_calibration_lookup_not_found(tmp_path): + calib_dir = tmp_path / "homography" + calib_dir.mkdir() + + video_stem = "nonexistent" + assert not (calib_dir / f"{video_stem}.json").exists() + + base = video_stem.split("_", 1)[0] + assert not (calib_dir / f"{base}.json").exists() + + +def test_export_csv_rows(tmp_path): + """Test CSV export logic.""" + rows = [ + {"frame": 0, "track_id": 1, "px": 640.0, "py": 500.0, "wx": "", "wy": "", "conf": 0.52}, + {"frame": 1, "track_id": 1, "px": 642.0, "py": 498.0, "wx": 50.1, "wy": 1.2, "conf": 0.55}, + ] + csv_path = tmp_path / "tracks.csv" + import csv + csv_path.parent.mkdir(parents=True, exist_ok=True) + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["frame", "track_id", "px", "py", "wx", "wy", "conf"]) + writer.writeheader() + writer.writerows(rows) + + lines = csv_path.read_text().strip().split("\n") + assert len(lines) == 3 + assert lines[0] == "frame,track_id,px,py,wx,wy,conf" + assert lines[1].endswith("0.52") + assert "50.1" in lines[2] +``` + +- [ ] **步骤 2:运行测试** + +运行:`python3 -m pytest tests/test_tracking.py -v` +预期:4/4 PASS + +- [ ] **步骤 3:Commit** + +```bash +git add tests/test_tracking.py +git commit -m "test: add calibration lookup and CSV export tests" +``` + +--- + +### 任务 3:创建 predict_track.py + +**文件:** +- 创建:`inference/predict_track.py` + +- [ ] **步骤 1:创建完整文件** + +```python +"""ByteTrack tracking + field coordinate mapping. + +Runs YOLO inference with ByteTrack on a video, maps detections to field +coordinates via homography, outputs annotated video + CSV trajectory data. + +Usage: + python3 inference/predict_track.py --video movie/test.mp4 + python3 inference/predict_track.py --video movie/test.mp4 --no-visualize +""" + +import argparse +import csv +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import cv2 +from ultralytics import YOLO + +from configs.models import V3_MODEL, DEFAULT_CONF +from utils.homography import load_calibration, pixel_to_world + + +def find_calibration(video_path: Path) -> dict | None: + """Auto-detect calibration JSON for a video. Returns None if not found.""" + calib_dir = Path("configs/homography") + if not calib_dir.exists(): + return None + + stem = video_path.stem + candidate = calib_dir / f"{stem}.json" + if candidate.exists(): + return load_calibration(candidate) + + if "_" in stem: + base = stem.split("_", 1)[0] + candidate = calib_dir / f"{base}.json" + if candidate.exists(): + return load_calibration(candidate) + + return None + + +def export_tracks_csv(rows: list[dict], output_path: Path) -> None: + """Write tracking data to CSV.""" + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["frame", "track_id", "px", "py", "wx", "wy", "conf"]) + writer.writeheader() + writer.writerows(rows) + + +def get_args(): + parser = argparse.ArgumentParser(description="ByteTrack tracking with field coordinate mapping") + parser.add_argument("--video", required=True, help="Path to input video") + parser.add_argument("--model", type=str, default=str(V3_MODEL)) + parser.add_argument("--calibration", type=str, default=None, help="Explicit calibration JSON path") + parser.add_argument("--conf", type=float, default=DEFAULT_CONF) + parser.add_argument("--output-dir", type=str, default=None) + parser.add_argument("--no-visualize", action="store_true", default=False, help="Skip video output") + return parser.parse_args() + + +def main(): + args = get_args() + video_path = Path(args.video) + if not video_path.exists(): + print(f"ERROR: Video not found: {video_path}") + sys.exit(1) + + print(f"Video: {video_path.name}") + print(f"Model: {args.model}") + print(f"Conf: {args.conf}") + + calib = None + matrix = None + if args.calibration: + calib = load_calibration(Path(args.calibration)) + matrix = calib["matrix"] + print(f"Calib: {args.calibration}") + else: + calib = find_calibration(video_path) + if calib: + matrix = calib["matrix"] + print("Calib: auto-detected") + else: + print("Calib: NONE — output will be pixel-only") + + output_dir = Path(args.output_dir or f"runs/track/{video_path.stem}") + output_dir.mkdir(parents=True, exist_ok=True) + print(f"Output: {output_dir}") + + model = YOLO(args.model) + print("\nTracking...") + + all_rows: list[dict] = [] + out_video = None + + cap = cv2.VideoCapture(str(video_path)) + fps = cap.get(cv2.CAP_PROP_FPS) + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + + if not args.no_visualize: + out_video = cv2.VideoWriter( + str(output_dir / f"{video_path.stem}_tracked.mp4"), + cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h), + ) + + results = model.track( + source=str(video_path), + conf=args.conf, + tracker="bytetrack.yaml", + persist=True, + stream=True, + verbose=False, + ) + + for frame_idx, r in enumerate(results, 1): + orig_frame = r.orig_img if hasattr(r, "orig_img") else None + if orig_frame is None: + continue + + if r.boxes is not None and r.boxes.id is not None: + for box, tid, conf_val in zip(r.boxes.xyxy, r.boxes.id, r.boxes.conf): + px = float((box[0] + box[2]) / 2.0) + py = float(box[3]) + tid_int = int(tid) + + row = {"frame": frame_idx, "track_id": tid_int, "px": round(px, 1), + "py": round(py, 1), "wx": "", "wy": "", "conf": round(float(conf_val), 4)} + + if matrix is not None: + try: + wx, wy = pixel_to_world(matrix, px, py) + row["wx"] = round(wx, 2) + row["wy"] = round(wy, 2) + except Exception: + pass + + all_rows.append(row) + + if out_video is not None: + x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3]) + cv2.rectangle(orig_frame, (x1, y1), (x2, y2), (0, 255, 0), 2) + tid_label = f"#{tid_int}" + cv2.putText(orig_frame, tid_label, (x1, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) + if row["wx"] != "": + coord_text = f"({row['wx']},{row['wy']})m" + cv2.putText(orig_frame, coord_text, (x2 - 120, y2 + 15), + cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1) + + if out_video is not None: + info = f"Frame: {frame_idx} Dets: {len(r.boxes) if r.boxes else 0}" + cv2.putText(orig_frame, info, (10, 25), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2) + out_video.write(orig_frame) + + if frame_idx % 100 == 0: + print(f" {frame_idx} frames, {len(all_rows)} detections") + + if out_video is not None: + out_video.release() + + csv_path = output_dir / f"{video_path.stem}_tracks.csv" + export_tracks_csv(all_rows, csv_path) + dets_per_sec = len(all_rows) / max(frame_idx / fps, 0.001) if fps > 0 else 0 + print(f"\nDone: {frame_idx} frames, {len(all_rows)} detections ({dets_per_sec:.1f} dets/s)") + print(f"CSV: {csv_path}") + if out_video is not None: + print(f"Video: {output_dir / f'{video_path.stem}_tracked.mp4'}") + + +if __name__ == "__main__": + main() +``` + +- [ ] **步骤 2:编译验证** + +运行:`python3 -c "import py_compile; py_compile.compile('inference/predict_track.py', doraise=True)"` +预期:无输出 + +- [ ] **步骤 3:Commit** + +```bash +git add inference/predict_track.py +git commit -m "feat: add ByteTrack tracking + coordinate mapping pipeline" +``` + +--- + +### 任务 4:端到端验证 + +**文件:** 无修改 + +- [ ] **步骤 1:在 55-56min 视频上运行(像素-only,无标定)** + +```bash +python3 inference/predict_track.py \ + --video movie/25866279684-1-192_55-56min.mp4 +``` + +预期: +- 输出 `runs/track/25866279684-1-192_55-56min/` 目录 +- 包含 `.mp4` 和 `_tracks.csv` +- 控制台出现 "NONE — output will be pixel-only" +- CSV 中 wx, wy 为空 + +- [ ] **步骤 2:在 20-23min 视频上运行** + +```bash +python3 inference/predict_track.py \ + --video movie/clip_20-23min.mp4 +``` + +预期:同上 + +- [ ] **步骤 3:检查 CSV 质量** + +```bash +python3 -c " +import csv +for fn in ['runs/track/25866279684-1-192_55-56min/25866279684-1-192_55-56min_tracks.csv', + 'runs/track/clip_20-23min/clip_20-23min_tracks.csv']: + try: + with open(fn) as f: + rows = list(csv.DictReader(f)) + print(f'{fn}: {len(rows)} rows, {len(set(r[\"track_id\"] for r in rows))} tracks') + except FileNotFoundError: + print(f'{fn}: not found') +" +``` + +预期:行数合理(500-5000 行),track_ids 不为 0 + +--- + +### 任务 5:全部测试回归 + +**文件:** 无修改 + +- [ ] **步骤 1:运行完整测试套件** + +```bash +python3 -m pytest tests/ -v +``` + +预期:全部 PASS(约 4*8 + 4 = ~26 个测试) + +--- + +### 任务 6:清除中间产物 + +**文件:** 无修改 + +- [ ] **步骤 1:清理临时视频输出** + +```bash +rm -rf runs/track/ +``` diff --git a/docs/superpowers/specs/2026-05-13-improve-precision-v5-design.md b/docs/superpowers/specs/2026-05-13-improve-precision-v5-design.md new file mode 100644 index 0000000..1ff985f --- /dev/null +++ b/docs/superpowers/specs/2026-05-13-improve-precision-v5-design.md @@ -0,0 +1,147 @@ +# v5 Precision Improvement — Design Spec + +> **Goal:** Reduce false positives (white hats → frisbee) from ~60% to <15% +> while keeping frame detection rate at 60-70%. + +## Summary + +v3/v4 models achieve good recall (68-79%) but produce ~60% false positives — +mostly white hats and rocks mistaken for frisbees. The root cause is `cls` loss +weight at YOLOv8's default 0.5, which barely penalizes classification errors. + +**Solution:** Retrain with `cls=1.3` (Ultralytics official optimal value), +using the existing 4950-image merged dataset. If that's not enough, expand +with OpenImages Flying disc images as backup (v6). + +## Architecture + +``` +主线 ──→ 训 v5 (cls=1.3, 现有合并4950张, 其余参数同v4) + │ + ▼ + 评估 v5 vs v4(验证集 + 两个测试视频) + │ + ┌────┴────┐ + │ FP<15%? │ + ├────┬────┤ + │ 是 │ 否 │ + ▼ ▼ + 完成 OpenImages 865张 → 训 v6 + +支线 ──→ 重新下载 OpenImages Flying disc 865张 + (重做 ID 收集 → 并发下载 → 转 YOLO 格式 → 记录 wiki) +``` + +## Data Inventory (Verified) + +| Source | Images | Type | +|--------|:------:|------| +| frisbee_ultimateml | 1001 | positive | +| frisbee_kaggle | 851 | positive | +| frisbee_coco (COCO frisbee) | 2713 | positive | +| frisbee_negatives | 80 | negative | +| frisbee_pseudo | 103 | mixed | +| frisbee_coco_neg | 445 | negative | +| frisbee_hard_neg | 202 | negative | +| **Total merged** | **4950** | — | +| **Negative ratio** | **16.6%** | ✅ already in range | + +## Training Parameters + +| Parameter | v4 | v5 | Change | +|-----------|:--:|:--:|--------| +| `cls` | 0.5 (default) | **1.3** | ✅ key change | +| `box` | 5 | 5 | same | +| `close_mosaic` | 10 | 10 | same | +| `epochs` | 100 | 100 | same | +| `batch` | 2 | 2 | same | +| `imgsz` | 1280 | 1280 | same | +| `patience` | 20 | 20 | same | +| `optimizer` | AdamW | AdamW | same | +| `data` | `configs/frisbee_merged.yaml` | same | same | + +**Prerequisite:** `train.py` must add `--cls` argparse parameter (currently missing). + +## Success Criteria + +| Metric | v3/v4 | v5 Target | Measurement | +|--------|:-----:|:---------:|-------------| +| Frame detection rate | 79.5% | 60-70% | Script auto | +| Dets per frame | 2.5 | 1.0-1.3 | Script auto | +| FP rate | ~60% | **<15%** | Manual spot-check 100 frames | +| Test video 1 | 55-56min | 55-56min | `predict_video.py` | +| Test video 2 | — | 20-23min | `predict_video.py` | + +FP rate spot-check method: sample 50 frames from each test video (100 total), +manually verify each detection — frisbee vs non-frisbee. + +## Steps + +### Step 1: Add `--cls` to `models/train.py` + +Add `--cls` argparse argument with default 0.5, wire into `train_frisbee_detector()`. + +### Step 2: Train v5 + +```bash +python3 models/train.py \ + --data configs/frisbee_merged.yaml \ + --model-size s \ + --name frisbee_det_s_v5 \ + --epochs 100 --imgsz 1280 --batch 2 \ + --patience 20 --box 5 --cls 1.3 --close-mosaic 10 +``` + +Run in tmux (~1.5-2.5h). Handle double-nesting bug (move model after training). + +### Step 3: Validate + +```bash +python3 models/train.py --validate-only \ + --model-path runs/detect/frisbee_det_s_v5/weights/best.pt \ + --data configs/frisbee_merged.yaml +``` + +### Step 4: Evaluate on test videos + +```bash +python3 inference/predict_video.py \ + --model runs/detect/frisbee_det_s_v5/weights/best.pt \ + --video movie/25866279684-1-192_55-56min.mp4 --conf 0.20 + +python3 inference/predict_video.py \ + --model runs/detect/frisbee_det_s_v5/weights/best.pt \ + --video movie/clip_20-23min.mp4 --conf 0.20 +``` + +### Step 5: Decision + +- **PASS** (dets/frame < 1.3, FP < 15%): Done. +- **FAIL**: Proceed to OpenImages download → train v6. + +## OpenImages Download (Backup Plan) + +1. Collect all Flying disc image IDs from existing text files in `openimages_frisbee/` +2. Download via CVDF mirror (not Google Storage — blocked) +3. Parse bbox from `oidv6-train-annotations-bbox.csv`, convert to YOLO format +4. Manual QC 20 images, record download summary to wiki +5. Add `frisbee_openimages` as new source in `merge_datasets.py` SOURCES list +6. Re-merge → train v6 with same params + +**Risk:** CVDF mirror may also return 403. If so, terminate this path. + +## Risks + +| Risk | Mitigation | +|------|-----------| +| cls=1.3 alone insufficient | OpenImages backup (v6) | +| v5 recall drops too far | v5 plan accepts tradeoff; frame detection 60-70% is target | +| v4 early-stopped at 31 epochs | v5 may also stop early — accept and evaluate | +| OpenImages download blocked | Terminate gracefully, record failure to wiki | + +## References + +- [Ultralytics Hyperparameter Tuning](https://docs.ultralytics.com/guides/hyperparameter-tuning/) — cls optimal 1.33 +- [Ultralytics #3466](https://github.com/ultralytics/ultralytics/issues/3466) — hard negative mining +- [Ultralytics #10207](https://github.com/ultralytics/ultralytics/issues/10207) — cls loss weight for class imbalance +- `.sisyphus/plans/reduce-false-positives-v5.md` — prior FP mitigation plan diff --git a/docs/superpowers/specs/2026-05-14-v7-data-driven-precision-design.md b/docs/superpowers/specs/2026-05-14-v7-data-driven-precision-design.md new file mode 100644 index 0000000..76d6ab2 --- /dev/null +++ b/docs/superpowers/specs/2026-05-14-v7-data-driven-precision-design.md @@ -0,0 +1,184 @@ +# v7 Data-Driven Precision Improvement — Design Spec + +> **Goal:** Reduce per-box FP rate from 89% to <20% while keeping frame detection rate >=50%. +> **Method:** Replace background images with bbox-level hard negatives (178 FP crops) + add TP frames with YOLO labels (22 frames). +> **Approach:** Quick validation (imgsz=640, 30 epochs) first, then production training (imgsz=1280) if metrics pass. + +## Problem Statement + +v3 at conf=0.35 achieves 69.9% frame detection rate on 55-56min test video but per-box +FP rate is 89% (178/200 boxes reviewed were false positives). TP and FP confidence +distributions fully overlap (mean ~0.52 for both) — threshold filtering is impossible. + +Previous attempts to fix this via `cls` loss weight tuning all failed: +- cls=0.5 (v3): 79.5% detection, ~60%+ FP +- cls=0.6 (v6b): 24.0% detection +- cls=0.8 (v6): 8.9% detection +- cls=1.3 (v5): 1.4% detection + +The model's internal classifier cannot distinguish frisbees from white hats/rocks +because it has never been shown what those FP objects look like as negatives. + +## Root Cause + +The training dataset has 283 generic background images (empty fields, random scenes) +but zero examples of "things that look like frisbees but aren't." The model's +classifier has never been penalized for confusing a white hat with a frisbee. + +## Solution: Data-Driven Approach (Scheme B) + +### Data Changes + +| Change | Count | Source | Safety | +|--------|:-----:|--------|--------| +| Remove 283 generic backgrounds | 283 | frisbee_negatives, coco_neg subsets | Safe — low-information images | +| Add 178 FP crops as backgrounds | 178 | `data/perbox_crops/` (reviewed FP) | Safe — bbox-level crops, not full frames | +| Add 22 TP frames with YOLO labels | 22 | 55-56min test video frames + v3 predictions | **Test scene leak accepted** | +| Keep 1498 existing positive images | 1498 | ultimateml, kaggle, pseudo | Unchanged | + +**Net dataset:** 1498 positive + 22 TP frames + 178 FP crops = 1698 training images +Background ratio: 200/1698 = **11.8%** (target was 10-20%) + +### Why Scheme B (Background Replacement) + +Original dataset had 4950 images (incl. 2179 corrupt COCO). After cleaning: 1781 images +with 283 backgrounds (15.9%). Replacing 283 generic backgrounds with 178 targeted hard +negatives gives the model much stronger negative signal without increasing dataset size. + +### Why 22 TP Frames Are Safe Despite Test Scene Leak + +The 22 TP frames come from the 55-56min test video. This is a data leak — the model +will see test scenes during training. However: + +1. Only 22 frames (1.3% of dataset) — minimal memorization risk +2. v3 already trained on pseudo-labeled frames from similar video sources +3. The primary metric is **FP rate reduction**, not test video mAP +4. 20-23min test video provides independent evaluation (different game, different venue) +5. Quick validation at imgsz=640 will show if the approach works before committing + +If quick validation succeeds, we can collect fresh TP frames from non-test sources +for the production training. + +## Training Configuration + +### Phase 1: Quick Validation (~10-15 min) + +| Parameter | Value | Rationale | +|-----------|:-----:|-----------| +| Base model | v3 (`frisbee_det_s_v3/weights/best.pt`) | Fine-tune from best existing | +| imgsz | 640 | 2x faster than 1280 | +| epochs | 30 | Enough to see trend | +| patience | 10 | Early stop if needed | +| batch | 2 | RTX 5080 16GB limit | +| box | 5 | Same as v3 | +| cls | 0.5 | Same as v3 — don't change | +| workers | 2 | Stable on this hardware | +| name | `frisbee_det_s_v7_quick` | | + +### Phase 2: Production Training (~1-2h, if Phase 1 passes) + +Same as Phase 1 but `imgsz=1280`, `epochs=100`, `patience=20`. +Name: `frisbee_det_s_v7`. + +### Decision Gate + +After Phase 1, evaluate on both test videos: + +| Metric | Threshold | Action | +|--------|:---------:|--------| +| FP rate | <20% | PASS → Phase 2 | +| Frame detection rate (55-56min) | >=50% | PASS → Phase 2 | +| Frame detection rate (20-23min) | >=50% | PASS → Phase 2 | +| Any metric fails | — | STOP — analyze, consider VLM expansion | + +If Phase 1 fails, do NOT proceed to Phase 2. Instead: +1. Analyze failure mode (FP still high? recall collapsed?) +2. Consider VLM-expanded hard negatives (use `collect_hard_negatives.py`) +3. Consider adding more TP frames from non-test videos + +## Preprocessing Script: `tools/prep_v7_data.py` + +### What It Does + +1. **Remove 283 background images** from current training set + - Identify by: empty label files (0 bytes) in `data/datasets/frisbee_merged/labels/train/` + - Remove both image and label file + +2. **Copy 178 FP crops** to training set + - Source: `data/perbox_crops/` files marked FP in `review_results.csv` + - Destination: `data/datasets/frisbee_merged/images/train/` + - Labels: empty `.txt` files (background/negative) + - Prefix: `hardneg_v7_` to distinguish from existing hard negatives + +3. **Extract 22 TP frames** from 55-56min test video + - Read frame numbers from TP entries in `review_results.csv` + - Open video, seek to frame, save as JPEG + - Generate YOLO label by re-running v3 inference on that frame + - Save image as `tp_v7_{frame_num:05d}.jpg` with corresponding label + - Destination: `data/datasets/frisbee_merged/images/train/` + +4. **Regenerate YAML config** via `utils.dataset.write_yaml_config()` + +### Safety Checks + +- Verify source video exists before extracting frames +- Verify v3 model exists for label generation +- Count final dataset and report positive/negative ratio +- Dry-run mode (`--dry-run`) to preview changes without modifying files + +## Evaluation Protocol + +After training: + +```bash +# 55-56min test video +python3 inference/predict_video.py \ + --model runs/detect/frisbee_det_s_v7_quick/weights/best.pt \ + --video movie/25866279684-1-192_55-56min.mp4 --conf 0.35 + +# 20-23min test video +python3 inference/predict_video.py \ + --model runs/detect/frisbee_det_s_v7_quick/weights/best.pt \ + --video movie/clip_20-23min.mp4 --conf 0.35 +``` + +Metrics to record: +- Frame detection rate (target: >=50%) +- Avg detections per frame (target: <1.0 — lower than v3's 1.86) +- Confidence distribution stats +- Manual FP spot-check on 50 random detection frames + +## Risks + +| Risk | Impact | Mitigation | +|------|--------|------------| +| FP crops too small (151-200px) for 640/1280 training | Model can't learn from tiny images | Images are padded/resized — YOLO will scale them | +| 22 TP frames cause overfitting to test scenes | Inflated metrics on 55-56min | Cross-check on 20-23min (different game) | +| Quick validation at imgsz=640 misleads | Wrong go/no-go decision | If borderline, re-run at imgsz=1280 before deciding | +| Background removal hurts generalization | Worse on new scenes | 178 FP crops are more informative than 283 generic backgrounds | +| Fine-tuning from v3 forgets existing features | Catastrophic forgetting | 30 epochs with low LR from pretrained weights is standard | + +## File Inventory + +| File | Role | +|------|------| +| `data/perbox_crops/review_results.csv` | 200 reviewed detections (22 TP, 178 FP) | +| `data/perbox_crops/crop_*.jpg` | Bbox-level crop images | +| `movie/25866279684-1-192_55-56min.mp4` | Test video (source of TP frames) | +| `runs/detect/frisbee_det_s_v3/weights/best.pt` | Base model for fine-tuning + TP label generation | +| `tools/prep_v7_data.py` | **NEW** — data preprocessing script | +| `models/train.py` | Training script (add `--cls` parameter if needed) | +| `inference/predict_video.py` | Evaluation script | + +## Expected Outcome + +v3's classifier sees ~1500 frisbee positives and ~300 generic backgrounds. It has +no examples of "frisbee-like objects that aren't frisbees." By replacing those 300 +generic backgrounds with 178 actual FP crops (white hats, rocks, lines) that the +model currently misclassifies, the classifier will learn to distinguish frisbees +from these specific confusers. + +The 22 TP frames ensure the model retains the ability to detect real frisbees in +game footage — preventing the recall collapse seen with cls tuning. + +Expected: FP rate drops from 89% to <20%, frame detection rate stays above 50%. diff --git a/docs/superpowers/specs/2026-05-18-homography-calibration-design.md b/docs/superpowers/specs/2026-05-18-homography-calibration-design.md new file mode 100644 index 0000000..fd277f7 --- /dev/null +++ b/docs/superpowers/specs/2026-05-18-homography-calibration-design.md @@ -0,0 +1,142 @@ +# Homography 场地标定工具设计 + +> **目标:** 创建 Streamlit 工具,在飞盘比赛视频帧上手动标定场地控制点,计算 Homography 变换矩阵,将像素坐标映射到真实场地坐标(100m × 37m),为后续追踪 + 物理过滤 pipeline 提供基础。 +> +> **背景:** 行业方案采用 检测→追踪→轨迹验证 策略,物理过滤需要场地坐标。固定机位 + 清晰标线已验证可行(14帧 GLM-4V-Flash 评分 7-8/10)。 + +## 硬件与约束 + +- **场地尺寸:** 标准飞盘场地 100m × 37m +- **视频源:** 1280×720, 25fps, 固定机位,全程未切换 +- **标定频率:** 只需一次(固定机位,整段视频共用同一个 Homography 矩阵)。**2026-05-15 通过 GLM-4V-Flash 验证(14帧 评分 7-8/10),确认可行。** +- **坐标系:** 左下角 (0, 0) 为原点,X 向右(长边 100m),Y 向上(短边 37m) + +## 组件设计 + +### 1. 标定 UI(tools/calibrate_field.py) + +Streamlit 页面功能: + +- `streamlit run tools/calibrate_field.py -- --video ` 启动 +- 自动加载视频,提取第一帧作为标定基准帧 +- 用户在图像上点击 4-8 个控制点(场地标线交叉点) +- 每个点弹出输入框,填写对应真实坐标 (world_x, world_y) +- 图像上实时标注已选点位置和编号 +- 提供"撤销上一点"和"清空所有点"操作 +- 点选完成后点击"计算 Homography"按钮 +- 计算后显示:矩阵、重投影误差、各点偏差 + +交互流程: + +``` +加载视频 → 显示第一帧 → 用户点击标线交叉点 + → 输入真实坐标 → 实时显示点位 + → 点击"计算 Homography" + → 显示矩阵 + 误差 + 网格线叠加 + 鸟瞰图 + → 点击"保存标定" → 写入 JSON +``` + +### 2. Homography 计算(utils/homography.py) + +``` +函数: + pixel_to_world(matrix, px, py) → (wx, wy) + world_to_pixel(matrix, wx, wy) → (px, py) + compute_homography(points) → matrix, error + draw_field_overlay(image, matrix) → image(网格线叠加) + warp_to_birdseye(image, matrix) → image(鸟瞰图) +``` + +`compute_homography` 实现: +- 输入:`[(px, py, wx, wy), ...]` +- 4 点 → `cv2.getPerspectiveTransform()` 精确解 +- 5+ 点 → `cv2.findHomography(method=cv2.RANSAC, ransacReprojThreshold=3.0)` 过约束 +- 输出:3×3 矩阵 + 均方根重投影误差(单位:像素) + +### 3. 可视化验证 + +**网格线叠加:** +- 在原始帧上绘制标准场地线(四条边线 + 两条得分线 + 中线) +- 每条线均匀采样 100 个世界坐标点,用 `world_to_pixel` 投影到图像 +- 用绿色绘制,半透明覆盖,与实际标线对比 + +**鸟瞰图:** +- 目标尺寸:1000×370 px(每米 10 像素) +- 使用 `cv2.warpPerspective()` 将整帧 warp 到俯视角度 +- 输出应该呈现长方形场地,边线水平/垂直 + +**误差报告:** +- 每个控制点的像素偏差表格(实际图像点 vs 矩阵反投影点) +- 均方根误差 +- 判断标定质量:误差 < 3px → 优秀,3-8px → 可接受,> 8px → 建议重标 + +### 4. 数据持久化 + +存储目录:`configs/homography/` + +文件命名:`.json` + +结构: +```json +{ + "video": "25866279684-1-192.mp4", + "image_size": [1280, 720], + "field_size_m": [100, 37], + "calibration_frame": 0, + "points": [ + {"pixel": [x, y], "world": [x, y], "error_px": 0.5}, + ... + ], + "matrix": [[h11, h12, h13], [h21, h22, h23], [h31, h32, 1.0]], + "reprojection_error_px": 0.8 +} +``` + +载入时验证 JSON schema:检查必要字段、矩阵形状、非空 point 列表。 + +## 文件结构 + +``` +tools/ + calibrate_field.py # Streamlit 应用(主入口) +utils/ + homography.py # Homography 计算 + 可视化函数 +tests/ + test_homography.py # 单元测试 +configs/homography/ # 标定结果 JSON 存储目录 +``` + +## 错误处理 + +| 场景 | 处理 | +|------|------| +| 不满 4 个点点击计算 | Streamlit 弹窗提示"至少需要 4 个点" | +| 点共线导致矩阵奇异 | catch OpenCV 异常,提示"控制点近似共线,请调整" | +| JSON 文件无法解析 | 提示文件损坏,建议重新标定 | +| 视频路径不存在 | 启动时 Argparse 验证路径,报错退出 | +| 鸟瞰图尺寸不正确 | 固定 1000×370 输出,不依赖输入 | + +## 测试 + +- `tests/test_homography.py` — 3 个测试 + 1. 已知变换的点和矩阵,验证正反变换互逆 + 2. 4 个共线点 → 预期抛出错误 + 3. 5+ 个随机点 RANSAC → 验证误差小于阈值 + +## 后续使用方式 + +```python +from utils.homography import load_homography, pixel_to_world + +matrix = load_homography("configs/homography/25866279684-1-192.json") +# 对追踪轨迹每条检测: +world_x, world_y = pixel_to_world(matrix, bbox_center_x, bbox_center_y) +# 计算速度、加速度、位置连续性 → 物理过滤 +``` + +## 不在此设计中的范围 + +- 追踪器实现(ByteTrack/BoT-SORT) +- 物理过滤逻辑(速度/轨迹连续性检查) +- 自动标线检测(完全手动标定) +- 多机位标定(单机位) diff --git a/docs/superpowers/specs/2026-05-18-tracking-pipeline-design.md b/docs/superpowers/specs/2026-05-18-tracking-pipeline-design.md new file mode 100644 index 0000000..862ca17 --- /dev/null +++ b/docs/superpowers/specs/2026-05-18-tracking-pipeline-design.md @@ -0,0 +1,147 @@ +# 追踪 + 场地坐标映射 Pipeline 设计 + +> **目标:** 在 v3 检测输出上跑 ByteTrack 追踪,用 Homography 将像素坐标映射到飞盘场地真实坐标(100m × 37m),输出标注视频 + 结构化轨迹数据,为后续物理过滤 pipeline 提供基础。 +> +> **前置条件:** Homography 标定工具已完成(标定 JSON 存储在 `configs/homography/`)。固定机位视频,只需标定一次。 + +## 硬件与约束 + +- **视频源:** 1280×720, 25fps, 固定机位 +- **模型:** `runs/detect/frisbee_det_s_v3/weights/best.pt`,conf=0.35 +- **追踪器:** YOLOv8 内置 ByteTrack +- **标定文件:** `configs/homography/.json` +- **输出:** 标注视频 + CSV 轨迹数据 + +## 架构 + +``` +v3 模型 (.pt) ──┐ + ├──→ YOLO.track() → ByteTrack → 每帧检测+track ID +Homography JSON ─┘ ↓ + pixel_to_world() + bbox 底边中点 → (wx, wy) + ↓ + ┌──→ 输出视频(标注框 + ID + 坐标) + │ + └──→ 输出 CSV(frame, track_id, px, py, wx, wy, conf) +``` + +## 组件设计 + +### 1. 追踪引擎(inference/predict_track.py) + +主入口脚本,包含: + +**参数:** +- `--video`:输入视频路径(必填) +- `--model`:模型路径(默认 v3) +- `--calibration`:标定 JSON 路径(可选,自动查找规则见下文) +- `--conf`:置信度阈值(默认 0.35) +- `--output-dir`:输出目录(默认 `runs/track//`) +- `--no-visualize`:是否跳过生成标注视频(默认生成) + +**标定文件自动查找规则:** +1. 如果传了 `--calibration`,直接使用 +2. 否则按优先级尝试: + a. `configs/homography/.json`(精确匹配) + b. `configs/homography/.json`(如 `25866279684-1-192_55-56min` → `25866279684-1-192.json`) +3. 找不到则打印警告 `⚠️ No calibration found, output will be pixel-only`,CSV 中 `wx, wy` 留空,**不退出**。 + +**流程:** +```python +model = YOLO(str(model_path)) +calib = load_calibration(cal_path) +matrix = calib["matrix"] + +results = model.track( + source=str(video_path), + conf=conf_threshold, + tracker="bytetrack.yaml", + persist=True, + stream=True, + verbose=False, +) + +for frame_idx, r in enumerate(results): + if r.boxes is None or r.boxes.id is None: + continue + for box, tid, conf in zip(r.boxes.xyxy, r.boxes.id, r.boxes.conf): + cx = (box[0] + box[2]) / 2 + by = box[3] # bbox bottom + wx, wy = pixel_to_world(matrix, float(cx), float(by)) + # draw + accumulate +``` + +**关键细节:** +- `stream=True`:逐帧产生结果,避免内存爆炸 +- `persist=True`:帧间保持追踪状态 +- bbox 底边中点 → 场地坐标(飞盘与地面接触点) + +### 2. 视频标注渲染 + +同脚本内实现,用 OpenCV 在每帧上绘制: +- 每个检测框(绿色矩形) +- 左上角 track ID 编号(如 `#3`) +- 右下角场地坐标(如 `(45.2, 18.5)m`) +- 帧号和检测总数在画面顶部 + +### 3. CSV 轨迹导出 + +每帧收集数据,在脚本结束时写入 CSV: + +``` +frame, track_id, px, py, wx, wy, conf +0, 1, 640.0, 500.0, 50.0, 0.0, 0.52 +0, 2, 320.0, 200.0, 25.0, 25.0, 0.38 +... +``` + +### 4. 错误处理 + +| 场景 | 处理 | +|------|------| +| 标定 JSON 不存在 | 警告并继续,仅输出像素坐标,CSV 中 `wx, wy` 留空 | +| 某帧追踪无结果 | 跳过该帧,不中断 | +| bbox 坐标越界 | clamp 到图像边界后在标注 | +| CSV 写入失败 | print 错误消息,不覆盖可能已存在的同名 CSV | +| Homography 矩阵空 | 跳过坐标映射,只输出像素坐标 | + +### 5. 使用方式 + +```bash +# 基本用法(自动查找标定文件,弱匹配) +python3 inference/predict_track.py --video movie/25866279684-1-192_55-56min.mp4 + +# 指定标定文件(精确匹配) +python3 inference/predict_track.py \ + --video movie/25866279684-1-192_55-56min.mp4 \ + --calibration configs/homography/25866279684-1-192.json + +# 只输出数据不生成视频 +python3 inference/predict_track.py \ + --video movie/25866279684-1-192_55-56min.mp4 \ + --no-visualize +``` + +## 文件变更 + +| 文件 | 操作 | 职责 | +|------|------|------| +| `inference/predict_track.py` | 新建 | 追踪 pipeline 主入口(约 200 行) | + +只有一个文件。坐标映射复用 `utils/homography.py` 的 `pixel_to_world()` 和 `load_calibration()`。 + +## 不在此设计中的范围 + +- 物理过滤逻辑(速度/加速度/轨迹连续性 —— 下个阶段) +- 多机位标定(单机位) +- 其他追踪器(只用 ByteTrack) +- 自动标定(需手动标定) +- 追踪器的单元测试(追踪逻辑嵌入在 cv2 循环中,适合集成验证而非单元测试) + +## 验证方式 + +1. 在 55-56min 视频上运行 +2. 检查输出视频:框是否跟随飞盘、track ID 是否稳定、场地坐标是否合理 +3. 检查 CSV:字段完整、坐标值在 [0,100]×[0,37] 范围内 +4. 在 20-23min 视频上重复验证 diff --git a/inference/predict_track.py b/inference/predict_track.py new file mode 100644 index 0000000..022e394 --- /dev/null +++ b/inference/predict_track.py @@ -0,0 +1,183 @@ +"""ByteTrack tracking + field coordinate mapping. + +Runs YOLO inference with ByteTrack on a video, maps detections to field +coordinates via homography, outputs annotated video + CSV trajectory data. + +Usage: + python3 inference/predict_track.py --video movie/test.mp4 + python3 inference/predict_track.py --video movie/test.mp4 --no-visualize +""" + +import argparse +import csv +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import cv2 +from ultralytics import YOLO + +from configs.models import V3_MODEL, DEFAULT_CONF +from utils.homography import load_calibration, pixel_to_world + + +def find_calibration(video_path: Path) -> dict | None: + """Auto-detect calibration JSON for a video. Returns None if not found.""" + calib_dir = Path("configs/homography") + if not calib_dir.exists(): + return None + + stem = video_path.stem + candidate = calib_dir / f"{stem}.json" + if candidate.exists(): + return load_calibration(candidate) + + if "_" in stem: + base = stem.split("_", 1)[0] + candidate = calib_dir / f"{base}.json" + if candidate.exists(): + return load_calibration(candidate) + + return None + + +def export_tracks_csv(rows: list[dict], output_path: Path) -> None: + """Write tracking data to CSV.""" + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["frame", "track_id", "px", "py", "wx", "wy", "conf"]) + writer.writeheader() + writer.writerows(rows) + + +def get_args(): + parser = argparse.ArgumentParser(description="ByteTrack tracking with field coordinate mapping") + parser.add_argument("--video", required=True, help="Path to input video") + parser.add_argument("--model", type=str, default=str(V3_MODEL)) + parser.add_argument("--calibration", type=str, default=None, help="Explicit calibration JSON path") + parser.add_argument("--conf", type=float, default=DEFAULT_CONF) + parser.add_argument("--output-dir", type=str, default=None) + parser.add_argument("--no-visualize", action="store_true", default=False, help="Skip video output") + return parser.parse_args() + + +def main(): + args = get_args() + video_path = Path(args.video) + if not video_path.exists(): + print(f"ERROR: Video not found: {video_path}") + sys.exit(1) + + print(f"Video: {video_path.name}") + print(f"Model: {args.model}") + print(f"Conf: {args.conf}") + + calib = None + matrix = None + if args.calibration: + calib = load_calibration(Path(args.calibration)) + matrix = calib["matrix"] + print(f"Calib: {args.calibration}") + else: + calib = find_calibration(video_path) + if calib: + matrix = calib["matrix"] + print("Calib: auto-detected") + else: + print("Calib: NONE — output will be pixel-only") + + output_dir = Path(args.output_dir or f"runs/track/{video_path.stem}") + output_dir.mkdir(parents=True, exist_ok=True) + print(f"Output: {output_dir}") + + model = YOLO(args.model) + print("\nTracking...") + + all_rows: list[dict] = [] + out_video = None + + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + print(f"ERROR: Cannot open video: {video_path}") + sys.exit(1) + fps = cap.get(cv2.CAP_PROP_FPS) + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + + if not args.no_visualize: + out_video = cv2.VideoWriter( + str(output_dir / f"{video_path.stem}_tracked.mp4"), + cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h), + ) + + results = model.track( + source=str(video_path), + conf=args.conf, + tracker="bytetrack.yaml", + persist=True, + stream=True, + verbose=False, + ) + + frame_idx = 0 + for frame_idx, r in enumerate(results, 1): + orig_frame = r.orig_img if hasattr(r, "orig_img") else None + if orig_frame is None: + continue + + if r.boxes is not None and r.boxes.id is not None: + for box, tid, conf_val in zip(r.boxes.xyxy, r.boxes.id, r.boxes.conf): + px = float((box[0] + box[2]) / 2.0) + py = float(box[3]) + tid_int = int(tid) + + row = {"frame": frame_idx, "track_id": tid_int, "px": round(px, 1), + "py": round(py, 1), "wx": None, "wy": None, "conf": round(float(conf_val), 4)} + + if matrix is not None: + try: + wx, wy = pixel_to_world(matrix, px, py) + row["wx"] = round(wx, 2) + row["wy"] = round(wy, 2) + except Exception: + print(f" WARNING: pixel_to_world failed frame {frame_idx}, track {tid_int}") + + all_rows.append(row) + + if out_video is not None: + x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3]) + cv2.rectangle(orig_frame, (x1, y1), (x2, y2), (0, 255, 0), 2) + tid_label = f"#{tid_int}" + cv2.putText(orig_frame, tid_label, (x1, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) + if row["wx"] is not None: + coord_text = f"({row['wx']},{row['wy']})m" + cv2.putText(orig_frame, coord_text, (x2 - 120, y2 + 15), + cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1) + + if out_video is not None: + info = f"Frame: {frame_idx} Dets: {len(r.boxes) if r.boxes else 0}" + cv2.putText(orig_frame, info, (10, 25), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2) + out_video.write(orig_frame) + + if frame_idx % 100 == 0: + print(f" {frame_idx} frames, {len(all_rows)} detections") + + if out_video is not None: + out_video.release() + + total_frames = frame_idx + csv_path = output_dir / f"{video_path.stem}_tracks.csv" + export_tracks_csv(all_rows, csv_path) + dets_per_sec = len(all_rows) / max(total_frames / fps, 0.001) if fps > 0 and total_frames > 0 else 0 + print(f"\nDone: {total_frames} frames, {len(all_rows)} detections ({dets_per_sec:.1f} dets/s)") + print(f"CSV: {csv_path}") + if out_video is not None: + print(f"Video: {output_dir / f'{video_path.stem}_tracked.mp4'}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_homography.py b/tests/test_homography.py new file mode 100644 index 0000000..54a6235 --- /dev/null +++ b/tests/test_homography.py @@ -0,0 +1,158 @@ +"""Tests for homography calibration utilities.""" + +import os +import sys +import tempfile + +import numpy as np +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from utils.homography import ( + compute_homography, + pixel_to_world, + world_to_pixel, + draw_field_overlay, + warp_to_birdseye, + save_calibration, + load_calibration, +) + + +def test_roundtrip_known_transform(): + pixel_pts = np.array([ + [100, 500], + [1180, 500], + [1180, 50], + [100, 50], + ], dtype=np.float64) + + world_pts = np.array([ + [0, 0], + [100, 0], + [100, 37], + [0, 37], + ], dtype=np.float64) + + points = [(p[0], p[1], w[0], w[1]) for p, w in zip(pixel_pts, world_pts)] + matrix, error = compute_homography(points) + + assert matrix is not None + assert error < 1.0 + + for px, py, wx, wy in points: + result_wx, result_wy = pixel_to_world(matrix, px, py) + assert abs(result_wx - wx) < 0.5, f"wx: {result_wx} != {wx}" + assert abs(result_wy - wy) < 0.5, f"wy: {result_wy} != {wy}" + + result_px, result_py = world_to_pixel(matrix, wx, wy) + assert abs(result_px - px) < 1.0, f"px: {result_px} != {px}" + assert abs(result_py - py) < 1.0, f"py: {result_py} != {py}" + + +def test_collinear_points_return_none(): + collinear = [ + (100, 100, 0, 0), + (200, 200, 50, 18.5), + (300, 300, 100, 37), + (400, 400, 50, 18.5), + ] + matrix, error = compute_homography(collinear) + assert matrix is None or error > 50.0 + + +def test_ransac_five_points(): + pixel_pts = [ + (100, 500), + (640, 500), + (1180, 500), + (1180, 50), + (100, 50), + ] + world_pts = [ + (0, 0), + (50, 0), + (100, 0), + (100, 37), + (0, 37), + ] + points = [(p[0], p[1], w[0], w[1]) for p, w in zip(pixel_pts, world_pts)] + matrix, error = compute_homography(points) + assert matrix is not None + assert error < 2.0 + + for px, py, wx, wy in points: + rwx, rwy = pixel_to_world(matrix, px, py) + assert abs(rwx - wx) < 1.0 + assert abs(rwy - wy) < 1.0 + + +def _make_test_matrix(): + points = [ + (100, 500, 0, 0), + (1180, 500, 100, 0), + (1180, 50, 100, 37), + (100, 50, 0, 37), + ] + matrix, _ = compute_homography(points) + assert matrix is not None + return matrix + + +def test_draw_field_overlay_returns_same_size(): + matrix = _make_test_matrix() + img = np.zeros((720, 1280, 3), dtype=np.uint8) + result = draw_field_overlay(img, matrix) + assert result.shape == img.shape + assert result.dtype == np.uint8 + assert np.any(result > 0) + + +def test_warp_to_birdseye_returns_fixed_size(): + matrix = _make_test_matrix() + img = np.zeros((720, 1280, 3), dtype=np.uint8) + result = warp_to_birdseye(img, matrix) + assert result.shape == (370, 1000, 3) + assert result.dtype == np.uint8 + + +def test_save_load_roundtrip(tmp_path): + matrix = _make_test_matrix() + points = [ + {"pixel": [100, 500], "world": [0, 0]}, + {"pixel": [1180, 500], "world": [100, 0]}, + {"pixel": [1180, 50], "world": [100, 37]}, + {"pixel": [100, 50], "world": [0, 37]}, + ] + save_calibration( + path=tmp_path / "test.json", + video="test.mp4", + image_size=[1280, 720], + field_size_m=[100, 37], + calibration_frame=0, + points=points, + matrix=matrix, + reprojection_error_px=0.5, + ) + + data = load_calibration(tmp_path / "test.json") + assert data["video"] == "test.mp4" + assert data["image_size"] == [1280, 720] + assert data["field_size_m"] == [100, 37] + assert len(data["points"]) == 4 + assert np.allclose(data["matrix"], matrix) + assert abs(data["reprojection_error_px"] - 0.5) < 0.01 + + +def test_load_missing_file_raises(): + with pytest.raises(FileNotFoundError): + load_calibration("/tmp/nonexistent_homography.json") + + +def test_load_invalid_json_raises(tmp_path): + bad_file = tmp_path / "bad.json" + bad_file.write_text("not json") + with pytest.raises(ValueError): + load_calibration(bad_file) + diff --git a/tests/test_tracking.py b/tests/test_tracking.py new file mode 100644 index 0000000..6c05913 --- /dev/null +++ b/tests/test_tracking.py @@ -0,0 +1,74 @@ +"""Tests for tracking pipeline helper functions.""" + +import json +import os +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from inference.predict_track import export_tracks_csv, find_calibration + + +def test_find_calibration_exact_match(tmp_path, monkeypatch): + """Exact stem match should find the calibration file.""" + calib_dir = tmp_path / "configs" / "homography" + calib_dir.mkdir(parents=True) + (calib_dir / "test_video.json").write_text(json.dumps({ + "matrix": [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + "video": "test_video.mp4", "image_size": [1280, 720], + "field_size_m": [100, 37], "calibration_frame": 0, + "points": [{"pixel": [100, 500], "world": [0, 0]}], + "reprojection_error_px": 0.5, + })) + + monkeypatch.chdir(tmp_path) + result = find_calibration(Path("/videos/test_video.mp4")) + assert result is not None + assert result["video"] == "test_video.mp4" + + +def test_find_calibration_underscore_fallback(tmp_path, monkeypatch): + """Video name with underscore should fall back to base name.""" + calib_dir = tmp_path / "configs" / "homography" + calib_dir.mkdir(parents=True) + (calib_dir / "base.json").write_text(json.dumps({ + "matrix": [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + "video": "base.mp4", "image_size": [1280, 720], + "field_size_m": [100, 37], "calibration_frame": 0, + "points": [{"pixel": [100, 500], "world": [0, 0]}], + "reprojection_error_px": 0.5, + })) + + monkeypatch.chdir(tmp_path) + result = find_calibration(Path("/videos/base_55-56min.mp4")) + assert result is not None + assert result["video"] == "base.mp4" + + +def test_find_calibration_not_found(tmp_path, monkeypatch): + """No matching file should return None.""" + calib_dir = tmp_path / "configs" / "homography" + calib_dir.mkdir(parents=True) + + monkeypatch.chdir(tmp_path) + result = find_calibration(Path("/videos/nonexistent.mp4")) + assert result is None + + +def test_export_csv_rows(tmp_path): + """CSV export should produce correctly formatted output.""" + rows = [ + {"frame": 0, "track_id": 1, "px": 640.0, "py": 500.0, "wx": None, "wy": None, "conf": 0.52}, + {"frame": 1, "track_id": 1, "px": 642.0, "py": 498.0, "wx": 50.1, "wy": 1.2, "conf": 0.55}, + ] + csv_path = tmp_path / "tracks.csv" + export_tracks_csv(rows, csv_path) + + lines = csv_path.read_text().strip().split("\n") + assert len(lines) == 3 + assert lines[0] == "frame,track_id,px,py,wx,wy,conf" + assert lines[1].endswith("0.52") + assert "50.1" in lines[2] diff --git a/tools/calibrate_field.py b/tools/calibrate_field.py new file mode 100644 index 0000000..ea93a26 --- /dev/null +++ b/tools/calibrate_field.py @@ -0,0 +1,217 @@ +"""Field calibration tool using Streamlit. + +Click field line intersections on a video frame, enter real-world coordinates, +compute homography, verify with overlay and bird's-eye view. + +Usage: + streamlit run tools/calibrate_field.py -- --video movie/25866279684-1-192.mp4 +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import argparse + +import cv2 +import numpy as np +import streamlit as st + +from utils.homography import ( + compute_homography, + draw_field_overlay, + load_calibration, + pixel_to_world, + save_calibration, + warp_to_birdseye, + world_to_pixel, +) + +st.set_page_config(page_title="Field Calibration", layout="wide") + +FIELD_W, FIELD_H = 100, 37 + + +def get_args(): + parser = argparse.ArgumentParser(description="Field calibration tool") + parser.add_argument("--video", required=True, help="Path to video file") + parser.add_argument( + "--frame", type=int, default=0, + help="Frame index to use for calibration (default: 0)", + ) + args, _ = parser.parse_known_args() + return args + + +@st.cache_data +def extract_frame(video_path: str, frame_idx: int) -> np.ndarray | None: + cap = cv2.VideoCapture(video_path) + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) + ret, frame = cap.read() + cap.release() + if not ret: + return None + return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + +def main(): + args = get_args() + video_path = Path(args.video) + if not video_path.exists(): + st.error(f"Video not found: {video_path}") + st.info("Usage: streamlit run tools/calibrate_field.py -- --video ") + return + + st.title("Field Calibration Tool") + st.caption(f"Video: {video_path.name}") + + frame = extract_frame(str(video_path), args.frame) + if frame is None: + st.error("Failed to extract frame from video") + return + + h, w = frame.shape[:2] + st.caption(f"Frame size: {w}x{h} | Field: {FIELD_W}m x {FIELD_H}m | Origin: bottom-left") + + if "points" not in st.session_state: + st.session_state.points = [] + + col_left, col_right = st.columns([3, 1]) + + with col_left: + st.subheader("Calibration Frame") + annotated = frame.copy() + + for i, pt in enumerate(st.session_state.points): + px, py = int(pt["pixel"][0]), int(pt["pixel"][1]) + cv2.circle(annotated, (px, py), 8, (0, 255, 0), 2) + cv2.putText(annotated, str(i + 1), (px + 10, py - 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) + + st.image(annotated, use_container_width=True) + + st.subheader("Add Control Point") + input_cols = st.columns([1, 1, 1, 1, 1]) + with input_cols[0]: + px_in = st.number_input("Pixel X", min_value=0, max_value=w, value=w // 2, key="px") + with input_cols[1]: + py_in = st.number_input("Pixel Y", min_value=0, max_value=h, value=h // 2, key="py") + with input_cols[2]: + wx_in = st.number_input("World X (m)", min_value=0.0, max_value=float(FIELD_W), value=0.0, step=1.0, key="wx") + with input_cols[3]: + wy_in = st.number_input("World Y (m)", min_value=0.0, max_value=float(FIELD_H), value=0.0, step=1.0, key="wy") + with input_cols[4]: + st.markdown("
", unsafe_allow_html=True) + if st.button("Add Point", type="primary"): + st.session_state.points.append({ + "pixel": [float(px_in), float(py_in)], + "world": [float(wx_in), float(wy_in)], + }) + st.rerun() + + btn_cols = st.columns([1, 1, 1, 1]) + with btn_cols[0]: + if st.button("Undo Last"): + if st.session_state.points: + st.session_state.points.pop() + st.rerun() + with btn_cols[1]: + if st.button("Clear All"): + st.session_state.points = [] + st.rerun() + with btn_cols[2]: + compute_clicked = st.button("Compute Homography", type="primary") + with btn_cols[3]: + save_clicked = st.button("Save Calibration") + + with col_right: + st.subheader("Points") + if not st.session_state.points: + st.info("No points added yet. Enter pixel/world coordinates and click 'Add Point'.") + else: + for i, pt in enumerate(st.session_state.points): + st.text(f"#{i+1}: px=({pt['pixel'][0]:.0f},{pt['pixel'][1]:.0f}) " + f"→ w=({pt['world'][0]:.1f},{pt['world'][1]:.1f})") + + if compute_clicked: + pts = st.session_state.points + if len(pts) < 4: + st.error("Need at least 4 points to compute homography") + return + + raw_points = [ + (p["pixel"][0], p["pixel"][1], p["world"][0], p["world"][1]) + for p in pts + ] + matrix, error = compute_homography(raw_points) + + if matrix is None: + st.error("Homography computation failed. Points may be nearly collinear.") + return + + st.session_state.matrix = matrix + st.session_state.reproj_error = error + + quality = "Excellent" if error < 3 else ("Acceptable" if error < 8 else "Poor — consider recalibrating") + st.success(f"Homography computed — RMSE: {error:.2f}px ({quality})") + + st.subheader("Homography Matrix") + matrix_display = np.array2string(matrix, precision=4, suppress_small=True) + st.code(matrix_display, language="text") + + overlay = draw_field_overlay(frame, matrix) + birdseye = warp_to_birdseye(frame, matrix) + + vis_cols = st.columns(2) + with vis_cols[0]: + st.subheader("Field Line Overlay") + st.image(overlay, use_container_width=True) + with vis_cols[1]: + st.subheader("Bird's-Eye View") + st.image(birdseye, use_container_width=True) + + st.subheader("Per-Point Errors") + for i, pt in enumerate(pts): + px, py = pt["pixel"][0], pt["pixel"][1] + rwx, rwy = pixel_to_world(matrix, px, py) + err = ((rwx - pt["world"][0]) ** 2 + (rwy - pt["world"][1]) ** 2) ** 0.5 + st.text(f"#{i+1}: projected=({rwx:.2f},{rwy:.2f}) expected=({pt['world'][0]:.1f},{pt['world'][1]:.1f}) err={err:.2f}m") + + if save_clicked: + if "matrix" not in st.session_state: + st.error("Compute homography first") + return + + output_dir = Path("configs/homography") + output_path = output_dir / f"{video_path.stem}.json" + + point_errors = [] + matrix = st.session_state.matrix + for pt in st.session_state.points: + px, py = pt["pixel"][0], pt["pixel"][1] + rwx, rwy = pixel_to_world(matrix, px, py) + err_px = ((rwx - pt["world"][0]) ** 2 + (rwy - pt["world"][1]) ** 2) ** 0.5 + point_errors.append(round(err_px, 4)) + + points_with_errors = [] + for pt, err in zip(st.session_state.points, point_errors): + p = dict(pt) + p["error_px"] = err + points_with_errors.append(p) + + save_calibration( + path=output_path, + video=video_path.name, + image_size=[w, h], + field_size_m=[FIELD_W, FIELD_H], + calibration_frame=args.frame, + points=points_with_errors, + matrix=matrix, + reprojection_error_px=round(st.session_state.reproj_error, 4), + ) + st.success(f"Saved to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/utils/homography.py b/utils/homography.py new file mode 100644 index 0000000..2cf6665 --- /dev/null +++ b/utils/homography.py @@ -0,0 +1,174 @@ +"""Homography calibration: pixel-to-field coordinate mapping. + +Functions: + compute_homography — compute 3x3 transform from matched points + pixel_to_world — image pixel → field coordinate (meters) + world_to_pixel — field coordinate → image pixel + draw_field_overlay — project standard field lines onto image + warp_to_birdseye — generate top-down bird's-eye view + save_calibration — persist calibration to JSON + load_calibration — load calibration from JSON +""" + +import json +from pathlib import Path + +import cv2 +import numpy as np + + +def compute_homography(points: list[tuple[float, float, float, float]]) -> tuple[np.ndarray | None, float]: + """Compute homography from matched pixel/world points. + + Args: + points: list of (pixel_x, pixel_y, world_x, world_y) tuples. + 4 points → exact solve; 5+ → RANSAC. + + Returns: + (3x3 matrix, reprojection_rmse_px) or (None, inf) on failure. + """ + if len(points) < 4: + return None, float("inf") + + pixel_arr = np.array([[p[0], p[1]] for p in points], dtype=np.float64) + world_arr = np.array([[p[2], p[3]] for p in points], dtype=np.float64) + + if len(points) == 4: + matrix = cv2.getPerspectiveTransform( + pixel_arr.astype(np.float32), + world_arr.astype(np.float32), + ) + else: + matrix, _ = cv2.findHomography( + pixel_arr, world_arr, + method=cv2.RANSAC, + ransacReprojThreshold=3.0, + ) + + if matrix is None: + return None, float("inf") + + try: + inv_matrix = np.linalg.inv(matrix) + except np.linalg.LinAlgError: + return None, float("inf") + reprojected_pixels = cv2.perspectiveTransform(world_arr.reshape(1, -1, 2), inv_matrix).reshape(-1, 2) + errors = np.linalg.norm(reprojected_pixels - pixel_arr, axis=1) + rmse = float(np.sqrt(np.mean(errors ** 2))) + + return matrix, rmse + + +def pixel_to_world(matrix: np.ndarray, px: float, py: float) -> tuple[float, float]: + """Convert pixel coordinates to world coordinates.""" + pt = np.array([[[px, py]]], dtype=np.float64) + result = cv2.perspectiveTransform(pt, matrix).reshape(2) + return float(result[0]), float(result[1]) + + +def world_to_pixel(matrix: np.ndarray, wx: float, wy: float) -> tuple[float, float]: + """Convert world coordinates to pixel coordinates.""" + try: + inv = np.linalg.inv(matrix) + except np.linalg.LinAlgError: + return float("nan"), float("nan") + pt = np.array([[[wx, wy]]], dtype=np.float64) + result = cv2.perspectiveTransform(pt, inv).reshape(2) + return float(result[0]), float(result[1]) + + +FIELD_LINES_WORLD = [ + [(0, 0), (100, 0)], + [(100, 0), (100, 37)], + [(100, 37), (0, 37)], + [(0, 37), (0, 0)], + [(0, 18.5), (100, 18.5)], +] + +BIRDSEYE_SIZE = (1000, 370) +LINE_SAMPLES = 100 + + +def draw_field_overlay(image: np.ndarray, matrix: np.ndarray) -> np.ndarray: + overlay = image.copy() + try: + inv_matrix = np.linalg.inv(matrix) + except np.linalg.LinAlgError: + return overlay + for line in FIELD_LINES_WORLD: + pixels = [] + for i in range(LINE_SAMPLES + 1): + t = i / LINE_SAMPLES + wx = line[0][0] + t * (line[1][0] - line[0][0]) + wy = line[0][1] + t * (line[1][1] - line[0][1]) + px, py = world_to_pixel(matrix, wx, wy) + if not (np.isnan(px) or np.isnan(py)): + pixels.append([int(round(px)), int(round(py))]) + if len(pixels) > 1: + cv2.polylines(overlay, [np.array(pixels)], False, (0, 255, 0), 2, cv2.LINE_AA) + return overlay + + +def warp_to_birdseye(image: np.ndarray, matrix: np.ndarray) -> np.ndarray: + try: + inv_matrix = np.linalg.inv(matrix) + except np.linalg.LinAlgError: + return image + return cv2.warpPerspective(image, inv_matrix, BIRDSEYE_SIZE) + + +_CALIBRATION_SCHEMA_KEYS = { + "video", "image_size", "field_size_m", "calibration_frame", + "points", "matrix", "reprojection_error_px", +} + + +def save_calibration( + path: Path | str, + video: str, + image_size: list[int], + field_size_m: list[float], + calibration_frame: int, + points: list[dict], + matrix: np.ndarray, + reprojection_error_px: float, +) -> None: + """Save calibration data to JSON.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + data = { + "video": video, + "image_size": image_size, + "field_size_m": field_size_m, + "calibration_frame": calibration_frame, + "points": points, + "matrix": matrix.tolist(), + "reprojection_error_px": reprojection_error_px, + } + path.write_text(json.dumps(data, indent=2)) + + +def load_calibration(path: Path | str) -> dict: + """Load and validate calibration data from JSON.""" + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Calibration file not found: {path}") + + try: + data = json.loads(path.read_text()) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in {path}: {e}") + + missing = _CALIBRATION_SCHEMA_KEYS - set(data.keys()) + if missing: + raise ValueError(f"Missing keys in {path}: {missing}") + + if not data["points"]: + raise ValueError(f"Empty points list in {path}") + + mat = np.array(data["matrix"]) + if mat.shape != (3, 3): + raise ValueError(f"Matrix must be 3x3, got {mat.shape}") + + data["matrix"] = mat + return data