diff --git a/docs/evaluation.md b/docs/evaluation.md index a85dabf..c358225 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -41,6 +41,12 @@ print(results_str) The evaluation returns a dictionary of metrics and a formatted string with per-source breakdowns and averages. +## TreePolygons train / checkpoint eval + +Polygon training and `eval_checkpoint.py` use **`--eval-mode stream`** by default: metrics are updated each batch instead of building full `y_pred` / `y_true` lists (much lower peak RAM on large test splits). Metrics match **`--eval-mode legacy`**, which keeps the old “accumulate everything, then `dataset.eval()`” flow. + +For custom scripts, the pattern above (lists + `dataset.eval()`) is unchanged. + ## Evaluation visualizations For qualitative debugging, pass **`viz_dir`** and optionally **`viz_n_per_source`** (default `4`) to `eval()`. The library writes PNGs under `viz_dir`, grouped in subfolders by source name, with up to `viz_n_per_source` images per source (in dataloader order).
diff --git a/docs/public/NEON_benchmark.png b/docs/public/NEON_benchmark.png index 10d3623..eb28a36 100644 Binary files a/docs/public/NEON_benchmark.png and b/docs/public/NEON_benchmark.png differ diff --git a/docs/public/NEON_points.png b/docs/public/NEON_points.png index 87239bc..6c34789 100644 Binary files a/docs/public/NEON_points.png and b/docs/public/NEON_points.png differ diff --git a/docs/public/OAM-TCD.png b/docs/public/OAM-TCD.png index 3a3b93f..d3b2408 100644 Binary files a/docs/public/OAM-TCD.png and b/docs/public/OAM-TCD.png differ diff --git a/docs/public/OSBS_megaplot_2025.png b/docs/public/OSBS_megaplot_2025.png index 9039026..8cfef41 100644 Binary files a/docs/public/OSBS_megaplot_2025.png and b/docs/public/OSBS_megaplot_2025.png differ diff
--git a/docs/public/Puliti_and_Astrup_2022.png b/docs/public/Puliti_and_Astrup_2022.png index ebf8b01..739812b 100644 Binary files a/docs/public/Puliti_and_Astrup_2022.png and b/docs/public/Puliti_and_Astrup_2022.png differ diff --git a/docs/public/Radogoshi_et_al._2021.png b/docs/public/Radogoshi_et_al._2021.png index c657d74..48b2c71 100644 Binary files a/docs/public/Radogoshi_et_al._2021.png and b/docs/public/Radogoshi_et_al._2021.png differ diff --git a/docs/public/Reiersen_et_al._2022.png b/docs/public/Reiersen_et_al._2022.png index 6819117..a4e29a5 100644 Binary files a/docs/public/Reiersen_et_al._2022.png and b/docs/public/Reiersen_et_al._2022.png differ diff --git "a/docs/public/Sch\303\274tte_et_al._2025.png" "b/docs/public/Sch\303\274tte_et_al._2025.png" index 9f4a8ac..5f810a3 100644 Binary files "a/docs/public/Sch\303\274tte_et_al._2025.png" and "b/docs/public/Sch\303\274tte_et_al._2025.png" differ diff --git a/docs/public/SelvaBox.png b/docs/public/SelvaBox.png index 1b2b945..28d7133 100644 Binary files a/docs/public/SelvaBox.png and b/docs/public/SelvaBox.png differ diff --git a/docs/public/Sun_et_al._2022.png b/docs/public/Sun_et_al._2022.png index 8afa5be..f918316 100644 Binary files a/docs/public/Sun_et_al._2022.png and b/docs/public/Sun_et_al._2022.png differ diff --git a/docs/public/Troles_et_al._2024.png b/docs/public/Troles_et_al._2024.png index 0e73ea5..73a832e 100644 Binary files a/docs/public/Troles_et_al._2024.png and b/docs/public/Troles_et_al._2024.png differ diff --git a/docs/public/Vasquez_et_al._2023.png b/docs/public/Vasquez_et_al._2023.png index c858e90..f9c10e8 100644 Binary files a/docs/public/Vasquez_et_al._2023.png and b/docs/public/Vasquez_et_al._2023.png differ diff --git a/docs/public/Vasquez_et_al._2023_-_training.png b/docs/public/Vasquez_et_al._2023_-_training.png index a427741..0d7e06d 100644 Binary files a/docs/public/Vasquez_et_al._2023_-_training.png and b/docs/public/Vasquez_et_al._2023_-_training.png differ diff --git 
a/docs/public/Ventura_et_al._2022.png b/docs/public/Ventura_et_al._2022.png index 0ecdd18..a0fe1e1 100644 Binary files a/docs/public/Ventura_et_al._2022.png and b/docs/public/Ventura_et_al._2022.png differ diff --git a/docs/public/Weecology_University_Florida.png b/docs/public/Weecology_University_Florida.png index 7486a67..25131d9 100644 Binary files a/docs/public/Weecology_University_Florida.png and b/docs/public/Weecology_University_Florida.png differ diff --git a/docs/public/Weinstein_et_al._2018_unsupervised.png b/docs/public/Weinstein_et_al._2018_unsupervised.png index f7be81a..932a4a5 100644 Binary files a/docs/public/Weinstein_et_al._2018_unsupervised.png and b/docs/public/Weinstein_et_al._2018_unsupervised.png differ diff --git a/docs/public/Young_et_al._2025_weak_supervised.png b/docs/public/Young_et_al._2025_weak_supervised.png index d5389a1..aa4f15a 100644 Binary files a/docs/public/Young_et_al._2025_weak_supervised.png and b/docs/public/Young_et_al._2025_weak_supervised.png differ diff --git a/docs/public/Zuniga-Gonzalez_et_al._2023.png b/docs/public/Zuniga-Gonzalez_et_al._2023.png index c66cfb8..26f1061 100644 Binary files a/docs/public/Zuniga-Gonzalez_et_al._2023.png and b/docs/public/Zuniga-Gonzalez_et_al._2023.png differ diff --git a/existing_models/slurm/eval_deepforest.sbatch b/existing_models/slurm/eval_deepforest.sbatch index f086333..2f6fbb9 100644 --- a/existing_models/slurm/eval_deepforest.sbatch +++ b/existing_models/slurm/eval_deepforest.sbatch @@ -21,19 +21,19 @@ cd $REPO/existing_models/deepforest echo "=== DeepForest eval: split=$SPLIT ===" -uv run python eval_boxes.py \ - --root-dir "$ROOT_DIR" \ - --split-scheme "$SPLIT" \ - --batch-size 12 \ - --num-workers 4 \ - --output-dir "$OUT_BASE" - -uv run python eval_points.py \ - --root-dir "$ROOT_DIR" \ - --split-scheme "$SPLIT" \ - --batch-size 32 \ - --num-workers 4 \ - --output-dir "$OUT_BASE" +# uv run python eval_boxes.py \ +# --root-dir "$ROOT_DIR" \ +# --split-scheme "$SPLIT" \ +# 
--batch-size 12 \ +# --num-workers 4 \ +# --output-dir "$OUT_BASE" + +# uv run python eval_points.py \ +# --root-dir "$ROOT_DIR" \ +# --split-scheme "$SPLIT" \ +# --batch-size 32 \ +# --num-workers 4 \ +# --output-dir "$OUT_BASE" uv run python eval_polygons.py \ --root-dir "$ROOT_DIR" \ diff --git a/existing_models/slurm/eval_sam3.sbatch b/existing_models/slurm/eval_sam3.sbatch index 0835ace..36f4c58 100644 --- a/existing_models/slurm/eval_sam3.sbatch +++ b/existing_models/slurm/eval_sam3.sbatch @@ -21,19 +21,19 @@ cd $REPO/existing_models/sam3 echo "=== SAM3 eval: split=$SPLIT ===" -uv run python eval_boxes.py \ - --root-dir "$ROOT_DIR" \ - --split-scheme "$SPLIT" \ - --device cuda \ - --batch-size 8 \ - --output-dir "$OUT_BASE" - -uv run python eval_points.py \ - --root-dir "$ROOT_DIR" \ - --split-scheme "$SPLIT" \ - --device cuda \ - --batch-size 16 \ - --output-dir "$OUT_BASE" +# uv run python eval_boxes.py \ +# --root-dir "$ROOT_DIR" \ +# --split-scheme "$SPLIT" \ +# --device cuda \ +# --batch-size 8 \ +# --output-dir "$OUT_BASE" + +# uv run python eval_points.py \ +# --root-dir "$ROOT_DIR" \ +# --split-scheme "$SPLIT" \ +# --device cuda \ +# --batch-size 16 \ +# --output-dir "$OUT_BASE" uv run python eval_polygons.py \ --root-dir "$ROOT_DIR" \ diff --git a/src/milliontrees/datasets/polygon_stream_eval.py b/src/milliontrees/datasets/polygon_stream_eval.py new file mode 100644 index 0000000..d88fc73 --- /dev/null +++ b/src/milliontrees/datasets/polygon_stream_eval.py @@ -0,0 +1,266 @@ +"""Streaming TreePolygons evaluation without accumulating all preds/GT in memory.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch + +from milliontrees.common.eval_visualization import save_eval_visualizations +from milliontrees.common.metrics.all_metrics import DetectionMAP +from milliontrees.common.utils import maximum, minimum + + +def _disable_torchmetric_sync(metric: Any) -> None: + if torch.distributed.is_available() and 
torch.distributed.is_initialized(): + metric._to_sync = False + + +class TreePolygonsStreamingEvalState: + """Accumulates metrics batch-wise; results match ``standard_group_eval`` semantics.""" + + _EW_KEYS = ("accuracy", "recall", "maskaware_precision", "merge_commission") + _MAP_KEY = "mAP" + + def __init__(self, dataset: Any) -> None: + self._dataset = dataset + self._grouper = dataset._eval_grouper + self._n_groups = int(self._grouper.n_groups) + + self._ew: dict[str, dict[str, Any]] = {} + for key in self._EW_KEYS: + self._ew[key] = { + "sum": 0.0, + "n": 0, + "g_sum": torch.zeros(self._n_groups, dtype=torch.float64), + "g_cnt": torch.zeros(self._n_groups, dtype=torch.float64), + } + + from torchmetrics.detection import MeanAveragePrecision + + self._map_metric: DetectionMAP = dataset.metrics[self._MAP_KEY] + self._map_global = MeanAveragePrecision( + iou_type=self._map_metric.iou_type, class_metrics=False) + _disable_torchmetric_sync(self._map_global) + self._map_per_group = [ + MeanAveragePrecision(iou_type=self._map_metric.iou_type, + class_metrics=False) + for _ in range(self._n_groups) + ] + for m in self._map_per_group: + _disable_torchmetric_sync(m) + + def update( + self, + y_pred: list, + y_true: list, + metadata: torch.Tensor, + ) -> None: + if not isinstance(metadata, torch.Tensor): + metadata = torch.as_tensor(metadata) + g = self._grouper.metadata_to_group(metadata) + + for key in self._EW_KEYS: + metric = self._dataset.metrics[key] + v = metric._compute_element_wise(y_pred, y_true).float() + if v.device != g.device: + v = v.to(g.device) + st = self._ew[key] + st["sum"] += float(v.sum().item()) + st["n"] += int(v.numel()) + for gi in range(self._n_groups): + mask = g == gi + if mask.any(): + st["g_sum"][gi] += float(v[mask].sum().item()) + st["g_cnt"][gi] += float(mask.sum().item()) + + preds, targets = self._map_metric._format(y_pred, y_true) + self._map_global.update(preds, targets) + + for gi in range(self._n_groups): + mask = g == gi + if not 
mask.any(): + continue + idx = mask.nonzero(as_tuple=True)[0].tolist() + gp = [y_pred[i] for i in idx] + gt = [y_true[i] for i in idx] + p2, t2 = self._map_metric._format(gp, gt) + self._map_per_group[gi].update(p2, t2) + + def _finalize_elementwise(self, key: str, metric: Any) -> tuple[dict, str]: + st = self._ew[key] + results: dict[str, Any] = {} + results_str = "" + + if st["n"] == 0: + agg = torch.tensor(0.0) + else: + agg = torch.tensor(st["sum"] / st["n"]) + results[metric.agg_metric_field] = float(agg.item()) + results_str += f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n" + + group_avgs = torch.full((self._n_groups,), + float("nan"), + dtype=torch.float64) + group_counts = st["g_cnt"] + for gi in range(self._n_groups): + if st["g_cnt"][gi] > 0: + group_avgs[gi] = st["g_sum"][gi] / st["g_cnt"][gi] + + valid = group_counts > 0 + if valid.any(): + sub = group_avgs[valid].float() + sub = sub[~torch.isnan(sub)] + if sub.numel() == 0: + worst = torch.tensor(0.0) + else: + worst = metric.worst(sub) + else: + worst = torch.tensor(0.0) + + for group_idx in range(self._n_groups): + group_str = self._grouper.group_field_str(group_idx) + gv = group_avgs[group_idx] + results[f"{metric.name}_{group_str}"] = (float( + gv.item()) if not torch.isnan(gv) else float("nan")) + results[f"count_{group_str}"] = float(st["g_cnt"][group_idx].item()) + if st["g_cnt"][group_idx] == 0: + continue + results_str += ( + f" {self._grouper.group_str(group_idx)} " + f"[n = {int(st['g_cnt'][group_idx].item()):6d}]:\t" + f"{metric.name} = {float(group_avgs[group_idx].item()):5.3f}\n") + + results[metric.worst_group_metric_field] = float(worst.item()) + results_str += ( + f"Worst-group {metric.name}: {results[metric.worst_group_metric_field]:.3f}\n" + ) + return results, results_str + + def _finalize_map(self, metric: DetectionMAP) -> tuple[dict, str]: + results: dict[str, Any] = {} + results_str = "" + + n_any = int(self._ew["accuracy"]["g_cnt"].sum().item()) + if n_any == 
0: + results[metric.agg_metric_field] = 0.0 + results_str += f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n" + for group_idx in range(self._n_groups): + group_str = self._grouper.group_field_str(group_idx) + results[f"{metric.name}_{group_str}"] = 0.0 + results[f"count_{group_str}"] = 0.0 + results[metric.worst_group_metric_field] = 0.0 + results_str += f"Worst-group {metric.name}: 0.000\n" + return results, results_str + + _disable_torchmetric_sync(self._map_global) + agg_map = float(self._map_global.compute()["map"].item()) + results[metric.agg_metric_field] = agg_map + results_str += f"Average {metric.name}: {agg_map:.3f}\n" + + group_metrics_list: list[torch.Tensor] = [] + gcnt = self._ew["accuracy"]["g_cnt"] + for group_idx in range(self._n_groups): + group_str = self._grouper.group_field_str(group_idx) + if gcnt[group_idx] <= 0: + gv = torch.tensor(0.0) + else: + _disable_torchmetric_sync(self._map_per_group[group_idx]) + gv = self._map_per_group[group_idx].compute()["map"] + group_metrics_list.append(gv) + results[f"{metric.name}_{group_str}"] = float(gv.item()) + results[f"count_{group_str}"] = float(gcnt[group_idx].item()) + if gcnt[group_idx] == 0: + continue + results_str += (f" {self._grouper.group_str(group_idx)} " + f"[n = {int(gcnt[group_idx].item()):6d}]:\t" + f"{metric.name} = {float(gv.item()):5.3f}\n") + + stacked = torch.stack([t.float() for t in group_metrics_list]) + worst = metric.worst(stacked[gcnt > 0]) + results[metric.worst_group_metric_field] = float(worst.item()) + results_str += ( + f"Worst-group {metric.name}: {results[metric.worst_group_metric_field]:.3f}\n" + ) + return results, results_str + + def finalize( + self, + *, + viz_dir: str | None = None, + viz_y_pred: list | None = None, + viz_y_true: list | None = None, + viz_metadata: torch.Tensor | None = None, + viz_n_per_source: int = 4, + ) -> tuple[dict[str, Any], str]: + results: dict[str, Any] = {} + results_str = "" + + for key in self._EW_KEYS: + metric = 
self._dataset.metrics[key] + r, s = self._finalize_elementwise(key, metric) + results[key] = r + results_str += s + + map_metric: DetectionMAP = self._dataset.metrics[self._MAP_KEY] + r, s = self._finalize_map(map_metric) + results[self._MAP_KEY] = r + results_str += s + + # Match ``TreePolygonsDataset.eval`` / ``TreeBoxesDataset.eval`` (same key prefixes). + detection_accs: list[float] = [] + for k, v in results["accuracy"].items(): + if k.startswith("detection_acc_source:"): + d = k.split(":")[1] + count = results["accuracy"].get(f"source:{d}") + if count and count > 0: + detection_accs.append(float(v)) + detection_acc_avg_dom = float(np.array(detection_accs).mean()) + results["detection_acc_avg_dom"] = detection_acc_avg_dom + results_str = ( + f"Average detection_acc across source: {detection_acc_avg_dom:.3f}\n" + + results_str) + + from milliontrees.common.utils import format_eval_results + + formatted_results = format_eval_results(results, self._dataset) + results_str = formatted_results + "\n" + results_str + + if viz_dir is not None and viz_y_pred and viz_y_true and viz_metadata is not None: + paths = save_eval_visualizations( + self._dataset, + viz_y_pred, + viz_y_true, + viz_metadata, + viz_dir, + n_per_source=viz_n_per_source, + score_threshold=self._dataset.eval_score_threshold, + ) + results["eval_visualization_paths"] = [str(p) for p in paths] + + return results, results_str + + +def merge_viz_samples( + cap: dict[int, int], + metadata: torch.Tensor, + preds: list, + targets: list, + *, + viz_y_pred: list, + viz_y_true: list, + viz_rows: list[torch.Tensor], + n_per_source: int, +) -> None: + """Append up to ``n_per_source`` samples per source_id for visualization.""" + if not isinstance(metadata, torch.Tensor): + metadata = torch.as_tensor(metadata) + for i in range(len(preds)): + sid = int(metadata[i, 1].item()) + if cap.get(sid, 0) >= n_per_source: + continue + viz_y_pred.append(preds[i]) + viz_y_true.append(targets[i]) + 
viz_rows.append(metadata[i].clone()) + cap[sid] = cap.get(sid, 0) + 1 diff --git a/tests/test_TreePolygons.py b/tests/test_TreePolygons.py index d749719..1dbc35e 100644 --- a/tests/test_TreePolygons.py +++ b/tests/test_TreePolygons.py @@ -1,7 +1,10 @@ from milliontrees.datasets.TreePolygons import TreePolygonsDataset +from milliontrees.datasets.polygon_stream_eval import TreePolygonsStreamingEvalState from milliontrees.common.data_loaders import get_train_loader, get_eval_loader from milliontrees.common.metrics.all_metrics import MaskAwareMaskPrecision +import math + import torch import pytest import numpy as np @@ -118,6 +121,63 @@ def test_TreePolygons_eval(dataset): assert "maskaware_precision" in eval_results.keys() assert "merge_commission" in eval_results.keys() + +def test_TreePolygons_eval_stream_matches_legacy(dataset): + """Streaming eval must match legacy ``dataset.eval`` on the same predictions.""" + ds = TreePolygonsDataset(download=False, root_dir=dataset, version="0.0") + test_dataset = ds.get_subset("test") + test_loader = get_eval_loader("standard", test_dataset, batch_size=2) + + _, _, ref_tgt = test_dataset[0] + ref_masks = torch.as_tensor(ref_tgt["y"]).clone() + ref_boxes = torch.as_tensor(ref_tgt["bboxes"]).clone() + ref_labels = torch.as_tensor(ref_tgt["labels"]).clone() + + all_y_pred = [] + all_y_true = [] + state = TreePolygonsStreamingEvalState(ds) + + for metadata, x, y_true in test_loader: + batch = [{ + "y": ref_masks, + "bboxes": ref_boxes, + "labels": ref_labels, + "scores": torch.tensor([0.54] * len(ref_labels)), + }] + all_y_pred.extend(batch) + all_y_true.extend(y_true) + state.update(batch, y_true, metadata) + + legacy_results, _ = ds.eval( + y_pred=all_y_pred, + y_true=all_y_true, + metadata=test_dataset.metadata_array, + ) + stream_results, _ = state.finalize() + + for metric_name in ("accuracy", "recall", "maskaware_precision", "merge_commission", "mAP"): + lk = legacy_results[metric_name] + sk = stream_results[metric_name] + 
for key, lv in lk.items(): + if key == "eval_visualization_paths": + continue + sv = sk[key] + fv = float(lv) + fs = float(sv) + if math.isnan(fv): + assert math.isnan(fs), f"{metric_name}.{key} legacy=nan stream={sv}" + else: + assert fs == pytest.approx(fv, rel=1e-5, abs=1e-5), ( + f"{metric_name}.{key} legacy={lv} stream={sv}" + ) + lr_dom = float(legacy_results["detection_acc_avg_dom"]) + sr_dom = float(stream_results["detection_acc_avg_dom"]) + if math.isnan(lr_dom): + assert math.isnan(sr_dom) + else: + assert sr_dom == pytest.approx(lr_dom, rel=1e-5, abs=1e-5) + + def test_TreePolygons_download_url(dataset): ds = TreePolygonsDataset(download=False, root_dir=dataset, version="0.0") for version in ds._versions_dict.keys(): diff --git a/training/polygons/eval_checkpoint.py b/training/polygons/eval_checkpoint.py index 3a11a07..182bfbc 100644 --- a/training/polygons/eval_checkpoint.py +++ b/training/polygons/eval_checkpoint.py @@ -28,6 +28,13 @@ def main(): parser.add_argument("--output-dir", type=str, default=None) parser.add_argument("--viz-dir", type=str, default=None, help="Directory for per-source prediction overlay PNGs") + parser.add_argument( + "--eval-mode", + type=str, + default="stream", + choices=["stream", "legacy"], + help="stream = low-memory per-batch metrics; legacy = accumulate then dataset.eval()", + ) args = parser.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" @@ -46,8 +53,15 @@ def main(): ) test_subset = dataset.get_subset("test") - results, results_str = evaluate(model, dataset, test_subset, batch_size=args.batch_size, - device=device, viz_dir=args.viz_dir) + results, results_str = evaluate( + model, + dataset, + test_subset, + batch_size=args.batch_size, + device=device, + viz_dir=args.viz_dir, + eval_mode=args.eval_mode, + ) print(results_str) if args.output_dir: diff --git a/training/polygons/train_polygons.py b/training/polygons/train_polygons.py index 3d7eb45..c4186fd 100644 --- 
a/training/polygons/train_polygons.py +++ b/training/polygons/train_polygons.py @@ -10,8 +10,6 @@ import argparse import os -import warnings -from typing import List import numpy as np import pytorch_lightning as pl @@ -23,6 +21,10 @@ from milliontrees import get_dataset from milliontrees.common.data_loaders import get_train_loader, get_eval_loader +from milliontrees.datasets.polygon_stream_eval import ( + TreePolygonsStreamingEvalState, + merge_viz_samples, +) def get_mask_rcnn(num_classes=2): @@ -131,21 +133,74 @@ def format_predictions_for_eval(images, model, device, mask_threshold=0.5): return batch_y_pred -def evaluate(model, dataset, test_subset, batch_size=8, device="cuda", viz_dir=None): +def evaluate( + model, + dataset, + test_subset, + batch_size=8, + device="cuda", + viz_dir=None, + *, + eval_mode="stream", + viz_n_per_source=4, +): + """Run test-set evaluation. + + ``eval_mode``: + - ``stream`` (default): update metrics per batch; does not accumulate all + masks in Python lists (lower peak memory). + - ``legacy``: accumulate full ``y_pred`` / ``y_true`` lists then call + ``dataset.eval()`` once (previous behavior). 
+ """ test_loader = get_eval_loader("standard", test_subset, batch_size=batch_size) - all_y_pred, all_y_true = [], [] model.eval() + + if eval_mode == "legacy": + all_y_pred, all_y_true = [], [] + for batch in test_loader: + metadata, images, targets = batch + preds = format_predictions_for_eval(images, model, device) + for y_pred, image_targets in zip(preds, targets): + all_y_pred.append(y_pred) + all_y_true.append(image_targets) + return dataset.eval( + all_y_pred, + all_y_true, + test_subset.metadata_array[: len(all_y_true)], + viz_dir=viz_dir, + viz_n_per_source=viz_n_per_source, + ) + + if eval_mode != "stream": + raise ValueError(f"Unknown eval_mode: {eval_mode!r}; use 'stream' or 'legacy'.") + + state = TreePolygonsStreamingEvalState(dataset) + viz_cap: dict[int, int] = {} + viz_y_pred, viz_y_true, viz_rows = [], [], [] for batch in test_loader: metadata, images, targets = batch preds = format_predictions_for_eval(images, model, device) - for y_pred, image_targets in zip(preds, targets): - all_y_pred.append(y_pred) - all_y_true.append(image_targets) - results, results_str = dataset.eval( - all_y_pred, all_y_true, test_subset.metadata_array[:len(all_y_true)], + state.update(preds, targets, metadata) + if viz_dir is not None: + merge_viz_samples( + viz_cap, + metadata, + preds, + targets, + viz_y_pred=viz_y_pred, + viz_y_true=viz_y_true, + viz_rows=viz_rows, + n_per_source=viz_n_per_source, + ) + + viz_meta = torch.stack(viz_rows, dim=0) if viz_rows else None + return state.finalize( viz_dir=viz_dir, + viz_y_pred=viz_y_pred or None, + viz_y_true=viz_y_true or None, + viz_metadata=viz_meta, + viz_n_per_source=viz_n_per_source, ) - return results, results_str def main(): @@ -176,6 +231,13 @@ def main(): default=50, help="Number of validation batches per epoch (for fast debugging)", ) + parser.add_argument( + "--eval-mode", + type=str, + default="stream", + choices=["stream", "legacy"], + help="Test eval: 'stream' avoids holding the full test set in memory; 
'legacy' matches old behavior.", + ) args = parser.parse_args() os.makedirs(args.output_dir, exist_ok=True) @@ -253,7 +315,12 @@ def main(): device = "cuda" if torch.cuda.is_available() else "cpu" results, results_str = evaluate( - model, polygon_dataset, test_subset, batch_size=args.batch_size, device=device + model, + polygon_dataset, + test_subset, + batch_size=args.batch_size, + device=device, + eval_mode=args.eval_mode, ) print(results_str)