6 changes: 6 additions & 0 deletions docs/evaluation.md
@@ -41,6 +41,12 @@ print(results_str)

The evaluation returns a dictionary of metrics and a formatted string with per-source breakdowns and averages.

## TreePolygons train / checkpoint eval

Polygon training and `eval_checkpoint.py` use **`--eval-mode stream`** by default: metrics are updated batch by batch instead of accumulating full `y_pred` / `y_true` lists, which keeps peak RAM much lower on large test splits. Results match **`--eval-mode legacy`**, which preserves the old "accumulate everything, then call `dataset.eval()`" flow.

For custom scripts, the accumulate-then-`dataset.eval()` pattern above is unchanged; scripts that want the lower-memory path can drive the streaming state directly, as in the sketch below.
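A minimal sketch of that streaming loop, assuming a WILDS-style batch layout (`model`, `dataset`, and `test_loader` come from your own setup; adapt the unpacking to your dataloader):

```python
from milliontrees.datasets.polygon_stream_eval import TreePolygonsStreamingEvalState

state = TreePolygonsStreamingEvalState(dataset)
for images, targets, metadata in test_loader:  # assumed batch layout
    y_pred = model(images)  # per-image prediction dicts, as in the legacy flow
    state.update(y_pred, targets, metadata)  # metrics updated in place, no global lists

results, results_str = state.finalize()
print(results_str)
```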

## Evaluation visualizations

For qualitative debugging, pass **`viz_dir`** and optionally **`viz_n_per_source`** (default `4`) to `eval()`. The library writes PNGs under `viz_dir`, grouped in subfolders by source name, with up to `viz_n_per_source` images per source (in dataloader order).
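For example (a minimal call; `y_pred`, `y_true`, and `metadata` are the accumulated lists from the standard flow):

```python
results, results_str = dataset.eval(
    y_pred, y_true, metadata,
    viz_dir="eval_viz",       # PNGs written under eval_viz/<source_name>/
    viz_n_per_source=8,       # default is 4
)
```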
16 changes: 16 additions & 0 deletions docs/leaderboard.md
@@ -90,6 +90,7 @@ All point sources are used to train and predict all box sources.

### Random

| Model | Fine-tuned | Avg Mask Accuracy | Mask-Aware Precision | Script |
|---|:---:|---|---|---|
| DeepForest | ✓ | 0.232 | 0.872 | <small>`uv run python training/polygons/train.py --split-scheme random`</small> |
@@ -103,6 +104,21 @@

### Zero-shot

| Model | Fine-tuned | Avg Mask Accuracy | Mask-Aware Precision | Script |
|---|:---:|---|---|---|
| DeepForest | ✓ | 0.176 | — | <small>`uv run python training/polygons/train.py --split-scheme zeroshot`</small> |
| SAM3 | ✗ | 0.165 | — | <small>`uv run python docs/examples/sam3_polygons.py --device cuda --split-scheme zeroshot --hf-token $HF_TOKEN`</small> |
| DeepForest | ✗ | 0.102 | — | <small>`uv run python docs/examples/baseline_polygons.py --split-scheme zeroshot`</small> |

### Cross-geometry

Binary file modified docs/public/NEON_benchmark.png
Binary file modified docs/public/NEON_points.png
Binary file modified docs/public/OAM-TCD.png
Binary file modified docs/public/OSBS_megaplot_2025.png
Binary file modified docs/public/Puliti_and_Astrup_2022.png
Binary file modified docs/public/Radogoshi_et_al._2021.png
Binary file modified docs/public/Reiersen_et_al._2022.png
Binary file modified docs/public/Schütte_et_al._2025.png
Binary file modified docs/public/SelvaBox.png
Binary file modified docs/public/Sun_et_al._2022.png
Binary file modified docs/public/Troles_et_al._2024.png
Binary file modified docs/public/Vasquez_et_al._2023.png
Binary file modified docs/public/Vasquez_et_al._2023_-_training.png
Binary file modified docs/public/Ventura_et_al._2022.png
Binary file modified docs/public/Weecology_University_Florida.png
Binary file modified docs/public/Weinstein_et_al._2018_unsupervised.png
Binary file modified docs/public/Young_et_al._2025_weak_supervised.png
Binary file modified docs/public/Zuniga-Gonzalez_et_al._2023.png
26 changes: 13 additions & 13 deletions existing_models/slurm/eval_deepforest.sbatch
@@ -21,19 +21,19 @@ cd $REPO/existing_models/deepforest

echo "=== DeepForest eval: split=$SPLIT ==="

-uv run python eval_boxes.py \
-  --root-dir "$ROOT_DIR" \
-  --split-scheme "$SPLIT" \
-  --batch-size 12 \
-  --num-workers 4 \
-  --output-dir "$OUT_BASE"
-
-uv run python eval_points.py \
-  --root-dir "$ROOT_DIR" \
-  --split-scheme "$SPLIT" \
-  --batch-size 32 \
-  --num-workers 4 \
-  --output-dir "$OUT_BASE"
+# uv run python eval_boxes.py \
+#   --root-dir "$ROOT_DIR" \
+#   --split-scheme "$SPLIT" \
+#   --batch-size 12 \
+#   --num-workers 4 \
+#   --output-dir "$OUT_BASE"
+
+# uv run python eval_points.py \
+#   --root-dir "$ROOT_DIR" \
+#   --split-scheme "$SPLIT" \
+#   --batch-size 32 \
+#   --num-workers 4 \
+#   --output-dir "$OUT_BASE"

uv run python eval_polygons.py \
--root-dir "$ROOT_DIR" \
26 changes: 13 additions & 13 deletions existing_models/slurm/eval_sam3.sbatch
@@ -21,19 +21,19 @@ cd $REPO/existing_models/sam3

echo "=== SAM3 eval: split=$SPLIT ==="

-uv run python eval_boxes.py \
-  --root-dir "$ROOT_DIR" \
-  --split-scheme "$SPLIT" \
-  --device cuda \
-  --batch-size 8 \
-  --output-dir "$OUT_BASE"
-
-uv run python eval_points.py \
-  --root-dir "$ROOT_DIR" \
-  --split-scheme "$SPLIT" \
-  --device cuda \
-  --batch-size 16 \
-  --output-dir "$OUT_BASE"
+# uv run python eval_boxes.py \
+#   --root-dir "$ROOT_DIR" \
+#   --split-scheme "$SPLIT" \
+#   --device cuda \
+#   --batch-size 8 \
+#   --output-dir "$OUT_BASE"
+
+# uv run python eval_points.py \
+#   --root-dir "$ROOT_DIR" \
+#   --split-scheme "$SPLIT" \
+#   --device cuda \
+#   --batch-size 16 \
+#   --output-dir "$OUT_BASE"

uv run python eval_polygons.py \
--root-dir "$ROOT_DIR" \
266 changes: 266 additions & 0 deletions src/milliontrees/datasets/polygon_stream_eval.py
@@ -0,0 +1,266 @@
"""Streaming TreePolygons evaluation without accumulating all preds/GT in memory."""

from __future__ import annotations

from typing import Any

import numpy as np
import torch

from milliontrees.common.eval_visualization import save_eval_visualizations
from milliontrees.common.metrics.all_metrics import DetectionMAP


def _disable_torchmetric_sync(metric: Any) -> None:
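    # torchmetrics normally all-gathers state across ranks inside compute();
    # keep state rank-local so each process reports only what it has seen.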
if torch.distributed.is_available() and torch.distributed.is_initialized():
metric._to_sync = False


class TreePolygonsStreamingEvalState:
"""Accumulates metrics batch-wise; results match ``standard_group_eval`` semantics."""

_EW_KEYS = ("accuracy", "recall", "maskaware_precision", "merge_commission")
_MAP_KEY = "mAP"

def __init__(self, dataset: Any) -> None:
self._dataset = dataset
self._grouper = dataset._eval_grouper
self._n_groups = int(self._grouper.n_groups)

self._ew: dict[str, dict[str, Any]] = {}
for key in self._EW_KEYS:
self._ew[key] = {
"sum": 0.0,
"n": 0,
"g_sum": torch.zeros(self._n_groups, dtype=torch.float64),
"g_cnt": torch.zeros(self._n_groups, dtype=torch.float64),
}

from torchmetrics.detection import MeanAveragePrecision

self._map_metric: DetectionMAP = dataset.metrics[self._MAP_KEY]
self._map_global = MeanAveragePrecision(
iou_type=self._map_metric.iou_type, class_metrics=False)
_disable_torchmetric_sync(self._map_global)
self._map_per_group = [
MeanAveragePrecision(iou_type=self._map_metric.iou_type,
class_metrics=False)
for _ in range(self._n_groups)
]
for m in self._map_per_group:
_disable_torchmetric_sync(m)

def update(
self,
y_pred: list,
y_true: list,
metadata: torch.Tensor,
) -> None:
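        """Fold one batch of predictions/targets into the running metric state."""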
if not isinstance(metadata, torch.Tensor):
metadata = torch.as_tensor(metadata)
g = self._grouper.metadata_to_group(metadata)

for key in self._EW_KEYS:
metric = self._dataset.metrics[key]
v = metric._compute_element_wise(y_pred, y_true).float()
if v.device != g.device:
v = v.to(g.device)
st = self._ew[key]
st["sum"] += float(v.sum().item())
st["n"] += int(v.numel())
for gi in range(self._n_groups):
mask = g == gi
if mask.any():
st["g_sum"][gi] += float(v[mask].sum().item())
st["g_cnt"][gi] += float(mask.sum().item())

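        # mAP is not an element-wise average, so stream batches into persistent
        # MeanAveragePrecision objects: one global, one per group.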
preds, targets = self._map_metric._format(y_pred, y_true)
self._map_global.update(preds, targets)

for gi in range(self._n_groups):
mask = g == gi
if not mask.any():
continue
idx = mask.nonzero(as_tuple=True)[0].tolist()
gp = [y_pred[i] for i in idx]
gt = [y_true[i] for i in idx]
p2, t2 = self._map_metric._format(gp, gt)
self._map_per_group[gi].update(p2, t2)

def _finalize_elementwise(self, key: str, metric: Any) -> tuple[dict, str]:
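        """Reduce the running sums for one element-wise metric into results and report text."""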
st = self._ew[key]
results: dict[str, Any] = {}
results_str = ""

if st["n"] == 0:
agg = torch.tensor(0.0)
else:
agg = torch.tensor(st["sum"] / st["n"])
results[metric.agg_metric_field] = float(agg.item())
results_str += f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n"

group_avgs = torch.full((self._n_groups,),
float("nan"),
dtype=torch.float64)
group_counts = st["g_cnt"]
for gi in range(self._n_groups):
if st["g_cnt"][gi] > 0:
group_avgs[gi] = st["g_sum"][gi] / st["g_cnt"][gi]

valid = group_counts > 0
if valid.any():
sub = group_avgs[valid].float()
sub = sub[~torch.isnan(sub)]
if sub.numel() == 0:
worst = torch.tensor(0.0)
else:
worst = metric.worst(sub)
else:
worst = torch.tensor(0.0)

for group_idx in range(self._n_groups):
group_str = self._grouper.group_field_str(group_idx)
gv = group_avgs[group_idx]
results[f"{metric.name}_{group_str}"] = (float(
gv.item()) if not torch.isnan(gv) else float("nan"))
results[f"count_{group_str}"] = float(st["g_cnt"][group_idx].item())
if st["g_cnt"][group_idx] == 0:
continue
results_str += (
f" {self._grouper.group_str(group_idx)} "
f"[n = {int(st['g_cnt'][group_idx].item()):6d}]:\t"
f"{metric.name} = {float(group_avgs[group_idx].item()):5.3f}\n")

results[metric.worst_group_metric_field] = float(worst.item())
results_str += (
f"Worst-group {metric.name}: {results[metric.worst_group_metric_field]:.3f}\n"
)
return results, results_str

def _finalize_map(self, metric: DetectionMAP) -> tuple[dict, str]:
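        """Compute overall and per-group mAP from the accumulated torchmetrics state."""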
results: dict[str, Any] = {}
results_str = ""

n_any = int(self._ew["accuracy"]["g_cnt"].sum().item())
if n_any == 0:
results[metric.agg_metric_field] = 0.0
results_str += f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n"
for group_idx in range(self._n_groups):
group_str = self._grouper.group_field_str(group_idx)
results[f"{metric.name}_{group_str}"] = 0.0
results[f"count_{group_str}"] = 0.0
results[metric.worst_group_metric_field] = 0.0
results_str += f"Worst-group {metric.name}: 0.000\n"
return results, results_str

_disable_torchmetric_sync(self._map_global)
agg_map = float(self._map_global.compute()["map"].item())
results[metric.agg_metric_field] = agg_map
results_str += f"Average {metric.name}: {agg_map:.3f}\n"

group_metrics_list: list[torch.Tensor] = []
gcnt = self._ew["accuracy"]["g_cnt"]
for group_idx in range(self._n_groups):
group_str = self._grouper.group_field_str(group_idx)
if gcnt[group_idx] <= 0:
gv = torch.tensor(0.0)
else:
_disable_torchmetric_sync(self._map_per_group[group_idx])
gv = self._map_per_group[group_idx].compute()["map"]
group_metrics_list.append(gv)
results[f"{metric.name}_{group_str}"] = float(gv.item())
results[f"count_{group_str}"] = float(gcnt[group_idx].item())
if gcnt[group_idx] == 0:
continue
results_str += (f" {self._grouper.group_str(group_idx)} "
f"[n = {int(gcnt[group_idx].item()):6d}]:\t"
f"{metric.name} = {float(gv.item()):5.3f}\n")

stacked = torch.stack([t.float() for t in group_metrics_list])
worst = metric.worst(stacked[gcnt > 0])
results[metric.worst_group_metric_field] = float(worst.item())
results_str += (
f"Worst-group {metric.name}: {results[metric.worst_group_metric_field]:.3f}\n"
)
return results, results_str

def finalize(
self,
*,
viz_dir: str | None = None,
viz_y_pred: list | None = None,
viz_y_true: list | None = None,
viz_metadata: torch.Tensor | None = None,
viz_n_per_source: int = 4,
) -> tuple[dict[str, Any], str]:
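        """Assemble the final results dict and report string; optionally write viz PNGs."""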
results: dict[str, Any] = {}
results_str = ""

for key in self._EW_KEYS:
metric = self._dataset.metrics[key]
r, s = self._finalize_elementwise(key, metric)
results[key] = r
results_str += s

map_metric: DetectionMAP = self._dataset.metrics[self._MAP_KEY]
r, s = self._finalize_map(map_metric)
results[self._MAP_KEY] = r
results_str += s

# Match ``TreePolygonsDataset.eval`` / ``TreeBoxesDataset.eval`` (same key prefixes).
detection_accs: list[float] = []
        for k, v in results["accuracy"].items():
            if k.startswith("detection_acc_source:"):
                d = k.split(":", 1)[1]
                # Per-group counts are stored as "count_source:<name>" above,
                # so look that key up rather than the bare "source:<name>".
                count = results["accuracy"].get(f"count_source:{d}")
                if count and count > 0:
                    detection_accs.append(float(v))
        detection_acc_avg_dom = (float(np.mean(detection_accs))
                                 if detection_accs else float("nan"))
results["detection_acc_avg_dom"] = detection_acc_avg_dom
results_str = (
f"Average detection_acc across source: {detection_acc_avg_dom:.3f}\n"
+ results_str)

from milliontrees.common.utils import format_eval_results

formatted_results = format_eval_results(results, self._dataset)
results_str = formatted_results + "\n" + results_str

if viz_dir is not None and viz_y_pred and viz_y_true and viz_metadata is not None:
paths = save_eval_visualizations(
self._dataset,
viz_y_pred,
viz_y_true,
viz_metadata,
viz_dir,
n_per_source=viz_n_per_source,
score_threshold=self._dataset.eval_score_threshold,
)
results["eval_visualization_paths"] = [str(p) for p in paths]

return results, results_str


def merge_viz_samples(
cap: dict[int, int],
metadata: torch.Tensor,
preds: list,
targets: list,
*,
viz_y_pred: list,
viz_y_true: list,
viz_rows: list[torch.Tensor],
n_per_source: int,
) -> None:
"""Append up to ``n_per_source`` samples per source_id for visualization."""
if not isinstance(metadata, torch.Tensor):
metadata = torch.as_tensor(metadata)
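    # Keep the first n_per_source samples seen per source, in dataloader order.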
for i in range(len(preds)):
sid = int(metadata[i, 1].item())
if cap.get(sid, 0) >= n_per_source:
continue
viz_y_pred.append(preds[i])
viz_y_true.append(targets[i])
viz_rows.append(metadata[i].clone())
cap[sid] = cap.get(sid, 0) + 1