diff --git a/docs/evaluation.md b/docs/evaluation.md
index a85dabf..c358225 100644
--- a/docs/evaluation.md
+++ b/docs/evaluation.md
@@ -41,6 +41,12 @@ print(results_str)
The evaluation returns a dictionary of metrics and a formatted string with per-source breakdowns and averages.
+## TreePolygons train / checkpoint eval
+
+Polygon training and `eval_checkpoint.py` use **`--eval-mode stream`** by default: metrics are updated each batch instead of building full `y_pred` / `y_true` lists (much lower peak RAM on large test splits). Metrics match **`--eval-mode legacy`**, which keeps the old “accumulate everything, then `dataset.eval()`” flow.
+
+For custom scripts, the pattern above (lists + `dataset.eval()`) is unchanged.
+
## Evaluation visualizations
For qualitative debugging, pass **`viz_dir`** and optionally **`viz_n_per_source`** (default `4`) to `eval()`. The library writes PNGs under `viz_dir`, grouped in subfolders by source name, with up to `viz_n_per_source` images per source (in dataloader order).
diff --git a/docs/public/NEON_benchmark.png b/docs/public/NEON_benchmark.png
index 10d3623..eb28a36 100644
Binary files a/docs/public/NEON_benchmark.png and b/docs/public/NEON_benchmark.png differ
diff --git a/docs/public/NEON_points.png b/docs/public/NEON_points.png
index 87239bc..6c34789 100644
Binary files a/docs/public/NEON_points.png and b/docs/public/NEON_points.png differ
diff --git a/docs/public/OAM-TCD.png b/docs/public/OAM-TCD.png
index 3a3b93f..d3b2408 100644
Binary files a/docs/public/OAM-TCD.png and b/docs/public/OAM-TCD.png differ
diff --git a/docs/public/OSBS_megaplot_2025.png b/docs/public/OSBS_megaplot_2025.png
index 9039026..8cfef41 100644
Binary files a/docs/public/OSBS_megaplot_2025.png and b/docs/public/OSBS_megaplot_2025.png differ
diff --git a/docs/public/Puliti_and_Astrup_2022.png b/docs/public/Puliti_and_Astrup_2022.png
index ebf8b01..739812b 100644
Binary files a/docs/public/Puliti_and_Astrup_2022.png and b/docs/public/Puliti_and_Astrup_2022.png differ
diff --git a/docs/public/Radogoshi_et_al._2021.png b/docs/public/Radogoshi_et_al._2021.png
index c657d74..48b2c71 100644
Binary files a/docs/public/Radogoshi_et_al._2021.png and b/docs/public/Radogoshi_et_al._2021.png differ
diff --git a/docs/public/Reiersen_et_al._2022.png b/docs/public/Reiersen_et_al._2022.png
index 6819117..a4e29a5 100644
Binary files a/docs/public/Reiersen_et_al._2022.png and b/docs/public/Reiersen_et_al._2022.png differ
diff --git "a/docs/public/Sch\303\274tte_et_al._2025.png" "b/docs/public/Sch\303\274tte_et_al._2025.png"
index 9f4a8ac..5f810a3 100644
Binary files "a/docs/public/Sch\303\274tte_et_al._2025.png" and "b/docs/public/Sch\303\274tte_et_al._2025.png" differ
diff --git a/docs/public/SelvaBox.png b/docs/public/SelvaBox.png
index 1b2b945..28d7133 100644
Binary files a/docs/public/SelvaBox.png and b/docs/public/SelvaBox.png differ
diff --git a/docs/public/Sun_et_al._2022.png b/docs/public/Sun_et_al._2022.png
index 8afa5be..f918316 100644
Binary files a/docs/public/Sun_et_al._2022.png and b/docs/public/Sun_et_al._2022.png differ
diff --git a/docs/public/Troles_et_al._2024.png b/docs/public/Troles_et_al._2024.png
index 0e73ea5..73a832e 100644
Binary files a/docs/public/Troles_et_al._2024.png and b/docs/public/Troles_et_al._2024.png differ
diff --git a/docs/public/Vasquez_et_al._2023.png b/docs/public/Vasquez_et_al._2023.png
index c858e90..f9c10e8 100644
Binary files a/docs/public/Vasquez_et_al._2023.png and b/docs/public/Vasquez_et_al._2023.png differ
diff --git a/docs/public/Vasquez_et_al._2023_-_training.png b/docs/public/Vasquez_et_al._2023_-_training.png
index a427741..0d7e06d 100644
Binary files a/docs/public/Vasquez_et_al._2023_-_training.png and b/docs/public/Vasquez_et_al._2023_-_training.png differ
diff --git a/docs/public/Ventura_et_al._2022.png b/docs/public/Ventura_et_al._2022.png
index 0ecdd18..a0fe1e1 100644
Binary files a/docs/public/Ventura_et_al._2022.png and b/docs/public/Ventura_et_al._2022.png differ
diff --git a/docs/public/Weecology_University_Florida.png b/docs/public/Weecology_University_Florida.png
index 7486a67..25131d9 100644
Binary files a/docs/public/Weecology_University_Florida.png and b/docs/public/Weecology_University_Florida.png differ
diff --git a/docs/public/Weinstein_et_al._2018_unsupervised.png b/docs/public/Weinstein_et_al._2018_unsupervised.png
index f7be81a..932a4a5 100644
Binary files a/docs/public/Weinstein_et_al._2018_unsupervised.png and b/docs/public/Weinstein_et_al._2018_unsupervised.png differ
diff --git a/docs/public/Young_et_al._2025_weak_supervised.png b/docs/public/Young_et_al._2025_weak_supervised.png
index d5389a1..aa4f15a 100644
Binary files a/docs/public/Young_et_al._2025_weak_supervised.png and b/docs/public/Young_et_al._2025_weak_supervised.png differ
diff --git a/docs/public/Zuniga-Gonzalez_et_al._2023.png b/docs/public/Zuniga-Gonzalez_et_al._2023.png
index c66cfb8..26f1061 100644
Binary files a/docs/public/Zuniga-Gonzalez_et_al._2023.png and b/docs/public/Zuniga-Gonzalez_et_al._2023.png differ
diff --git a/src/milliontrees/datasets/polygon_stream_eval.py b/src/milliontrees/datasets/polygon_stream_eval.py
new file mode 100644
index 0000000..d88fc73
--- /dev/null
+++ b/src/milliontrees/datasets/polygon_stream_eval.py
@@ -0,0 +1,266 @@
+"""Streaming TreePolygons evaluation without accumulating all preds/GT in memory."""
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+import torch
+
+from milliontrees.common.eval_visualization import save_eval_visualizations
+from milliontrees.common.metrics.all_metrics import DetectionMAP
+from milliontrees.common.utils import maximum, minimum
+
+
+def _disable_torchmetric_sync(metric: Any) -> None:
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+ metric._to_sync = False
+
+
+class TreePolygonsStreamingEvalState:
+ """Accumulates metrics batch-wise; results match ``standard_group_eval`` semantics."""
+
+ _EW_KEYS = ("accuracy", "recall", "maskaware_precision", "merge_commission")
+ _MAP_KEY = "mAP"
+
+ def __init__(self, dataset: Any) -> None:
+ self._dataset = dataset
+ self._grouper = dataset._eval_grouper
+ self._n_groups = int(self._grouper.n_groups)
+
+ self._ew: dict[str, dict[str, Any]] = {}
+ for key in self._EW_KEYS:
+ self._ew[key] = {
+ "sum": 0.0,
+ "n": 0,
+ "g_sum": torch.zeros(self._n_groups, dtype=torch.float64),
+ "g_cnt": torch.zeros(self._n_groups, dtype=torch.float64),
+ }
+
+ from torchmetrics.detection import MeanAveragePrecision
+
+ self._map_metric: DetectionMAP = dataset.metrics[self._MAP_KEY]
+ self._map_global = MeanAveragePrecision(
+ iou_type=self._map_metric.iou_type, class_metrics=False)
+ _disable_torchmetric_sync(self._map_global)
+ self._map_per_group = [
+ MeanAveragePrecision(iou_type=self._map_metric.iou_type,
+ class_metrics=False)
+ for _ in range(self._n_groups)
+ ]
+ for m in self._map_per_group:
+ _disable_torchmetric_sync(m)
+
+ def update(
+ self,
+ y_pred: list,
+ y_true: list,
+ metadata: torch.Tensor,
+ ) -> None:
+ if not isinstance(metadata, torch.Tensor):
+ metadata = torch.as_tensor(metadata)
+ g = self._grouper.metadata_to_group(metadata)
+
+ for key in self._EW_KEYS:
+ metric = self._dataset.metrics[key]
+ v = metric._compute_element_wise(y_pred, y_true).float()
+ if v.device != g.device:
+ v = v.to(g.device)
+ st = self._ew[key]
+ st["sum"] += float(v.sum().item())
+ st["n"] += int(v.numel())
+ for gi in range(self._n_groups):
+ mask = g == gi
+ if mask.any():
+ st["g_sum"][gi] += float(v[mask].sum().item())
+ st["g_cnt"][gi] += float(mask.sum().item())
+
+ preds, targets = self._map_metric._format(y_pred, y_true)
+ self._map_global.update(preds, targets)
+
+ for gi in range(self._n_groups):
+ mask = g == gi
+ if not mask.any():
+ continue
+ idx = mask.nonzero(as_tuple=True)[0].tolist()
+ gp = [y_pred[i] for i in idx]
+ gt = [y_true[i] for i in idx]
+ p2, t2 = self._map_metric._format(gp, gt)
+ self._map_per_group[gi].update(p2, t2)
+
+ def _finalize_elementwise(self, key: str, metric: Any) -> tuple[dict, str]:
+ st = self._ew[key]
+ results: dict[str, Any] = {}
+ results_str = ""
+
+ if st["n"] == 0:
+ agg = torch.tensor(0.0)
+ else:
+ agg = torch.tensor(st["sum"] / st["n"])
+ results[metric.agg_metric_field] = float(agg.item())
+ results_str += f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n"
+
+ group_avgs = torch.full((self._n_groups,),
+ float("nan"),
+ dtype=torch.float64)
+ group_counts = st["g_cnt"]
+ for gi in range(self._n_groups):
+ if st["g_cnt"][gi] > 0:
+ group_avgs[gi] = st["g_sum"][gi] / st["g_cnt"][gi]
+
+ valid = group_counts > 0
+ if valid.any():
+ sub = group_avgs[valid].float()
+ sub = sub[~torch.isnan(sub)]
+ if sub.numel() == 0:
+ worst = torch.tensor(0.0)
+ else:
+ worst = metric.worst(sub)
+ else:
+ worst = torch.tensor(0.0)
+
+ for group_idx in range(self._n_groups):
+ group_str = self._grouper.group_field_str(group_idx)
+ gv = group_avgs[group_idx]
+ results[f"{metric.name}_{group_str}"] = (float(
+ gv.item()) if not torch.isnan(gv) else float("nan"))
+ results[f"count_{group_str}"] = float(st["g_cnt"][group_idx].item())
+ if st["g_cnt"][group_idx] == 0:
+ continue
+ results_str += (
+ f" {self._grouper.group_str(group_idx)} "
+ f"[n = {int(st['g_cnt'][group_idx].item()):6d}]:\t"
+ f"{metric.name} = {float(group_avgs[group_idx].item()):5.3f}\n")
+
+ results[metric.worst_group_metric_field] = float(worst.item())
+ results_str += (
+ f"Worst-group {metric.name}: {results[metric.worst_group_metric_field]:.3f}\n"
+ )
+ return results, results_str
+
+ def _finalize_map(self, metric: DetectionMAP) -> tuple[dict, str]:
+ results: dict[str, Any] = {}
+ results_str = ""
+
+ n_any = int(self._ew["accuracy"]["g_cnt"].sum().item())
+ if n_any == 0:
+ results[metric.agg_metric_field] = 0.0
+ results_str += f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n"
+ for group_idx in range(self._n_groups):
+ group_str = self._grouper.group_field_str(group_idx)
+ results[f"{metric.name}_{group_str}"] = 0.0
+ results[f"count_{group_str}"] = 0.0
+ results[metric.worst_group_metric_field] = 0.0
+ results_str += f"Worst-group {metric.name}: 0.000\n"
+ return results, results_str
+
+ _disable_torchmetric_sync(self._map_global)
+ agg_map = float(self._map_global.compute()["map"].item())
+ results[metric.agg_metric_field] = agg_map
+ results_str += f"Average {metric.name}: {agg_map:.3f}\n"
+
+ group_metrics_list: list[torch.Tensor] = []
+ gcnt = self._ew["accuracy"]["g_cnt"]
+ for group_idx in range(self._n_groups):
+ group_str = self._grouper.group_field_str(group_idx)
+ if gcnt[group_idx] <= 0:
+ gv = torch.tensor(0.0)
+ else:
+ _disable_torchmetric_sync(self._map_per_group[group_idx])
+ gv = self._map_per_group[group_idx].compute()["map"]
+ group_metrics_list.append(gv)
+ results[f"{metric.name}_{group_str}"] = float(gv.item())
+ results[f"count_{group_str}"] = float(gcnt[group_idx].item())
+ if gcnt[group_idx] == 0:
+ continue
+ results_str += (f" {self._grouper.group_str(group_idx)} "
+ f"[n = {int(gcnt[group_idx].item()):6d}]:\t"
+ f"{metric.name} = {float(gv.item()):5.3f}\n")
+
+ stacked = torch.stack([t.float() for t in group_metrics_list])
+ worst = metric.worst(stacked[gcnt > 0])
+ results[metric.worst_group_metric_field] = float(worst.item())
+ results_str += (
+ f"Worst-group {metric.name}: {results[metric.worst_group_metric_field]:.3f}\n"
+ )
+ return results, results_str
+
+ def finalize(
+ self,
+ *,
+ viz_dir: str | None = None,
+ viz_y_pred: list | None = None,
+ viz_y_true: list | None = None,
+ viz_metadata: torch.Tensor | None = None,
+ viz_n_per_source: int = 4,
+ ) -> tuple[dict[str, Any], str]:
+ results: dict[str, Any] = {}
+ results_str = ""
+
+ for key in self._EW_KEYS:
+ metric = self._dataset.metrics[key]
+ r, s = self._finalize_elementwise(key, metric)
+ results[key] = r
+ results_str += s
+
+ map_metric: DetectionMAP = self._dataset.metrics[self._MAP_KEY]
+ r, s = self._finalize_map(map_metric)
+ results[self._MAP_KEY] = r
+ results_str += s
+
+ # Match ``TreePolygonsDataset.eval`` / ``TreeBoxesDataset.eval`` (same key prefixes).
+ detection_accs: list[float] = []
+    for k, v in results["accuracy"].items():
+        if k.startswith("detection_acc_source:"):
+            d = k.split(":")[1]
+            count = results["accuracy"].get(f"count_source:{d}")
+            if count and count > 0:
+                detection_accs.append(float(v))
+    detection_acc_avg_dom = float(np.mean(detection_accs)) if detection_accs else float("nan")
+ results["detection_acc_avg_dom"] = detection_acc_avg_dom
+ results_str = (
+ f"Average detection_acc across source: {detection_acc_avg_dom:.3f}\n"
+ + results_str)
+
+ from milliontrees.common.utils import format_eval_results
+
+ formatted_results = format_eval_results(results, self._dataset)
+ results_str = formatted_results + "\n" + results_str
+
+ if viz_dir is not None and viz_y_pred and viz_y_true and viz_metadata is not None:
+ paths = save_eval_visualizations(
+ self._dataset,
+ viz_y_pred,
+ viz_y_true,
+ viz_metadata,
+ viz_dir,
+ n_per_source=viz_n_per_source,
+ score_threshold=self._dataset.eval_score_threshold,
+ )
+ results["eval_visualization_paths"] = [str(p) for p in paths]
+
+ return results, results_str
+
+
+def merge_viz_samples(
+ cap: dict[int, int],
+ metadata: torch.Tensor,
+ preds: list,
+ targets: list,
+ *,
+ viz_y_pred: list,
+ viz_y_true: list,
+ viz_rows: list[torch.Tensor],
+ n_per_source: int,
+) -> None:
+ """Append up to ``n_per_source`` samples per source_id for visualization."""
+ if not isinstance(metadata, torch.Tensor):
+ metadata = torch.as_tensor(metadata)
+ for i in range(len(preds)):
+ sid = int(metadata[i, 1].item())
+ if cap.get(sid, 0) >= n_per_source:
+ continue
+ viz_y_pred.append(preds[i])
+ viz_y_true.append(targets[i])
+ viz_rows.append(metadata[i].clone())
+ cap[sid] = cap.get(sid, 0) + 1
diff --git a/tests/test_TreePolygons.py b/tests/test_TreePolygons.py
index d749719..1dbc35e 100644
--- a/tests/test_TreePolygons.py
+++ b/tests/test_TreePolygons.py
@@ -1,7 +1,10 @@
from milliontrees.datasets.TreePolygons import TreePolygonsDataset
+from milliontrees.datasets.polygon_stream_eval import TreePolygonsStreamingEvalState
from milliontrees.common.data_loaders import get_train_loader, get_eval_loader
from milliontrees.common.metrics.all_metrics import MaskAwareMaskPrecision
+import math
+
import torch
import pytest
import numpy as np
@@ -118,6 +121,63 @@ def test_TreePolygons_eval(dataset):
assert "maskaware_precision" in eval_results.keys()
assert "merge_commission" in eval_results.keys()
+
+def test_TreePolygons_eval_stream_matches_legacy(dataset):
+ """Streaming eval must match legacy ``dataset.eval`` on the same predictions."""
+ ds = TreePolygonsDataset(download=False, root_dir=dataset, version="0.0")
+ test_dataset = ds.get_subset("test")
+ test_loader = get_eval_loader("standard", test_dataset, batch_size=2)
+
+ _, _, ref_tgt = test_dataset[0]
+ ref_masks = torch.as_tensor(ref_tgt["y"]).clone()
+ ref_boxes = torch.as_tensor(ref_tgt["bboxes"]).clone()
+ ref_labels = torch.as_tensor(ref_tgt["labels"]).clone()
+
+ all_y_pred = []
+ all_y_true = []
+ state = TreePolygonsStreamingEvalState(ds)
+
+ for metadata, x, y_true in test_loader:
+        batch = [{
+            "y": ref_masks,
+            "bboxes": ref_boxes,
+            "labels": ref_labels,
+            "scores": torch.tensor([0.54] * len(ref_labels)),
+        } for _ in range(len(y_true))]
+ all_y_pred.extend(batch)
+ all_y_true.extend(y_true)
+ state.update(batch, y_true, metadata)
+
+ legacy_results, _ = ds.eval(
+ y_pred=all_y_pred,
+ y_true=all_y_true,
+ metadata=test_dataset.metadata_array,
+ )
+ stream_results, _ = state.finalize()
+
+ for metric_name in ("accuracy", "recall", "maskaware_precision", "merge_commission", "mAP"):
+ lk = legacy_results[metric_name]
+ sk = stream_results[metric_name]
+ for key, lv in lk.items():
+ if key == "eval_visualization_paths":
+ continue
+ sv = sk[key]
+ fv = float(lv)
+ fs = float(sv)
+ if math.isnan(fv):
+ assert math.isnan(fs), f"{metric_name}.{key} legacy=nan stream={sv}"
+ else:
+ assert fs == pytest.approx(fv, rel=1e-5, abs=1e-5), (
+ f"{metric_name}.{key} legacy={lv} stream={sv}"
+ )
+ lr_dom = float(legacy_results["detection_acc_avg_dom"])
+ sr_dom = float(stream_results["detection_acc_avg_dom"])
+ if math.isnan(lr_dom):
+ assert math.isnan(sr_dom)
+ else:
+ assert sr_dom == pytest.approx(lr_dom, rel=1e-5, abs=1e-5)
+
+
def test_TreePolygons_download_url(dataset):
ds = TreePolygonsDataset(download=False, root_dir=dataset, version="0.0")
for version in ds._versions_dict.keys():
diff --git a/training/polygons/eval_checkpoint.py b/training/polygons/eval_checkpoint.py
index 3a11a07..182bfbc 100644
--- a/training/polygons/eval_checkpoint.py
+++ b/training/polygons/eval_checkpoint.py
@@ -28,6 +28,13 @@ def main():
parser.add_argument("--output-dir", type=str, default=None)
parser.add_argument("--viz-dir", type=str, default=None,
help="Directory for per-source prediction overlay PNGs")
+ parser.add_argument(
+ "--eval-mode",
+ type=str,
+ default="stream",
+ choices=["stream", "legacy"],
+ help="stream = low-memory per-batch metrics; legacy = accumulate then dataset.eval()",
+ )
args = parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -46,8 +53,15 @@ def main():
)
test_subset = dataset.get_subset("test")
- results, results_str = evaluate(model, dataset, test_subset, batch_size=args.batch_size,
- device=device, viz_dir=args.viz_dir)
+ results, results_str = evaluate(
+ model,
+ dataset,
+ test_subset,
+ batch_size=args.batch_size,
+ device=device,
+ viz_dir=args.viz_dir,
+ eval_mode=args.eval_mode,
+ )
print(results_str)
if args.output_dir:
diff --git a/training/polygons/train_polygons.py b/training/polygons/train_polygons.py
index 3d7eb45..c4186fd 100644
--- a/training/polygons/train_polygons.py
+++ b/training/polygons/train_polygons.py
@@ -10,8 +10,6 @@
import argparse
import os
-import warnings
-from typing import List
import numpy as np
import pytorch_lightning as pl
@@ -23,6 +21,10 @@
from milliontrees import get_dataset
from milliontrees.common.data_loaders import get_train_loader, get_eval_loader
+from milliontrees.datasets.polygon_stream_eval import (
+ TreePolygonsStreamingEvalState,
+ merge_viz_samples,
+)
def get_mask_rcnn(num_classes=2):
@@ -131,21 +133,74 @@ def format_predictions_for_eval(images, model, device, mask_threshold=0.5):
return batch_y_pred
-def evaluate(model, dataset, test_subset, batch_size=8, device="cuda", viz_dir=None):
+def evaluate(
+ model,
+ dataset,
+ test_subset,
+ batch_size=8,
+ device="cuda",
+ viz_dir=None,
+ *,
+ eval_mode="stream",
+ viz_n_per_source=4,
+):
+ """Run test-set evaluation.
+
+ ``eval_mode``:
+ - ``stream`` (default): update metrics per batch; does not accumulate all
+ masks in Python lists (lower peak memory).
+ - ``legacy``: accumulate full ``y_pred`` / ``y_true`` lists then call
+ ``dataset.eval()`` once (previous behavior).
+ """
test_loader = get_eval_loader("standard", test_subset, batch_size=batch_size)
- all_y_pred, all_y_true = [], []
model.eval()
+
+ if eval_mode == "legacy":
+ all_y_pred, all_y_true = [], []
+ for batch in test_loader:
+ metadata, images, targets = batch
+ preds = format_predictions_for_eval(images, model, device)
+ for y_pred, image_targets in zip(preds, targets):
+ all_y_pred.append(y_pred)
+ all_y_true.append(image_targets)
+ return dataset.eval(
+ all_y_pred,
+ all_y_true,
+ test_subset.metadata_array[: len(all_y_true)],
+ viz_dir=viz_dir,
+ viz_n_per_source=viz_n_per_source,
+ )
+
+ if eval_mode != "stream":
+ raise ValueError(f"Unknown eval_mode: {eval_mode!r}; use 'stream' or 'legacy'.")
+
+ state = TreePolygonsStreamingEvalState(dataset)
+ viz_cap: dict[int, int] = {}
+ viz_y_pred, viz_y_true, viz_rows = [], [], []
for batch in test_loader:
metadata, images, targets = batch
preds = format_predictions_for_eval(images, model, device)
- for y_pred, image_targets in zip(preds, targets):
- all_y_pred.append(y_pred)
- all_y_true.append(image_targets)
- results, results_str = dataset.eval(
- all_y_pred, all_y_true, test_subset.metadata_array[:len(all_y_true)],
+ state.update(preds, targets, metadata)
+ if viz_dir is not None:
+ merge_viz_samples(
+ viz_cap,
+ metadata,
+ preds,
+ targets,
+ viz_y_pred=viz_y_pred,
+ viz_y_true=viz_y_true,
+ viz_rows=viz_rows,
+ n_per_source=viz_n_per_source,
+ )
+
+ viz_meta = torch.stack(viz_rows, dim=0) if viz_rows else None
+ return state.finalize(
viz_dir=viz_dir,
+ viz_y_pred=viz_y_pred or None,
+ viz_y_true=viz_y_true or None,
+ viz_metadata=viz_meta,
+ viz_n_per_source=viz_n_per_source,
)
- return results, results_str
def main():
@@ -176,6 +231,13 @@ def main():
default=50,
help="Number of validation batches per epoch (for fast debugging)",
)
+ parser.add_argument(
+ "--eval-mode",
+ type=str,
+ default="stream",
+ choices=["stream", "legacy"],
+ help="Test eval: 'stream' avoids holding the full test set in memory; 'legacy' matches old behavior.",
+ )
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
@@ -253,7 +315,12 @@ def main():
device = "cuda" if torch.cuda.is_available() else "cpu"
results, results_str = evaluate(
- model, polygon_dataset, test_subset, batch_size=args.batch_size, device=device
+ model,
+ polygon_dataset,
+ test_subset,
+ batch_size=args.batch_size,
+ device=device,
+ eval_mode=args.eval_mode,
)
print(results_str)