Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions dingo/exec/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
"map_at_10",
]

# Names of the per-query raw-API metrics that get aggregated: every entry of
# METRICS_OF_INTEREST except "main_score", carrying the "raw_api_" prefix.
RAW_API_METRICS_OF_INTEREST = [
    f"raw_api_{metric_name}"
    for metric_name in METRICS_OF_INTEREST
    if metric_name != "main_score"
]


@Executor.register("retrieval")
class RetrievalExecutor:
Expand Down Expand Up @@ -133,6 +137,12 @@ def execute(self) -> SummaryModel:
model.get_search_traces(),
task_name,
)
task_metrics.update(
self._compute_raw_api_metrics_from_search_traces(
model.get_search_traces(),
task_name,
)
)
all_results[task_name] = task_metrics
except Exception as e:
logger.error(f"Task {task_name!r} failed: {e}", exc_info=True)
Expand All @@ -141,6 +151,12 @@ def execute(self) -> SummaryModel:
task_name,
)
if task_metrics:
task_metrics.update(
self._compute_raw_api_metrics_from_search_traces(
model.get_search_traces(),
task_name,
)
)
logger.warning(
"Using search trace fallback metrics for failed task %r",
task_name,
Expand Down Expand Up @@ -320,6 +336,37 @@ def _compute_metrics_from_search_traces(
if values
}

@staticmethod
def _compute_raw_api_metrics_from_search_traces(
    traces: list[dict[str, Any]],
    task_name: str,
) -> dict[str, float]:
    """Average the per-query raw-API metrics recorded in search traces.

    Only traces whose ``"task"`` equals ``task_name`` contribute. For each
    of their queries, every name in ``RAW_API_METRICS_OF_INTEREST`` that is
    present in the query's ``"raw_api_metrics"`` mapping is collected, and
    the query's ``"api_results_count"`` is collected when that key exists.

    Args:
        traces: Search trace dicts, each optionally holding ``"task"`` and
            ``"queries"`` entries.
        task_name: Task whose traces should be aggregated.

    Returns:
        Mean of each collected metric, rounded to 5 decimals. Metrics that
        were never observed are omitted. ``"raw_api_avg_results_count"`` is
        added when at least one query carried ``"api_results_count"``.
    """
    metric_values: dict[str, list[float]] = {}
    api_results_counts: list[float] = []
    for trace in traces:
        if trace.get("task") != task_name:
            continue
        # ``or []`` instead of a ``.get`` default: the default only applies
        # when the key is absent, so an explicit ``"queries": None`` would
        # otherwise raise ``TypeError: 'NoneType' object is not iterable``.
        for query in trace.get("queries") or []:
            raw_metrics = query.get("raw_api_metrics") or {}
            for key in RAW_API_METRICS_OF_INTEREST:
                value = raw_metrics.get(key)
                if value is None:
                    # Missing or null metric: nothing to average, and a
                    # None would make the sum() below fail.
                    continue
                metric_values.setdefault(key, []).append(value)
            if "api_results_count" in query:
                # A present-but-null count is treated as zero results.
                api_results_counts.append(
                    float(query.get("api_results_count") or 0)
                )

    summary = {
        key: round(sum(values) / len(values), 5)
        for key, values in metric_values.items()
        if values
    }
    if api_results_counts:
        summary["raw_api_avg_results_count"] = round(
            sum(api_results_counts) / len(api_results_counts),
            5,
        )
    return summary

def load_data(self):
pass

Expand Down
28 changes: 28 additions & 0 deletions dingo/retrieval/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,34 @@ def compute_query_metrics(
}


def build_raw_api_ranked_doc_ids(
    top_api_results: list[dict[str, Any]],
) -> list[str]:
    """Build a ranked doc-id list that preserves raw API ranks for metrics.

    Unmatched hits and duplicate resolved corpus IDs are kept as unique
    non-relevant placeholders so later API hits do not move upward.

    Args:
        top_api_results: API hits in rank order. Only each hit's
            ``"resolved_corpus_id"`` is read; a missing, empty or
            whitespace-only value marks the hit as unmatched. ``None`` is
            treated as an empty list.

    Returns:
        One entry per hit, in input order: the stripped resolved corpus ID
        on its first occurrence, ``"__raw_api_duplicate__<n>"`` on repeats,
        and ``"__raw_api_unmatched__<n>"`` for unmatched hits, where ``<n>``
        is the hit's 1-based position in ``top_api_results``.
    """
    ranked: list[str] = []
    seen_resolved_ids: set[str] = set()
    for position, result in enumerate(top_api_results or (), start=1):
        resolved_id = str(result.get("resolved_corpus_id") or "").strip()
        # Placeholders are suffixed with the loop position rather than the
        # hit's own "rank" field: positions are strictly increasing, whereas
        # reported ranks may repeat or be zero, which would make two
        # placeholders identical.
        if not resolved_id:
            ranked.append(f"__raw_api_unmatched__{position}")
        elif resolved_id in seen_resolved_ids:
            ranked.append(f"__raw_api_duplicate__{position}")
        else:
            seen_resolved_ids.add(resolved_id)
            ranked.append(resolved_id)
    return ranked


def prefix_metrics(metrics: dict[str, Any], prefix: str) -> dict[str, Any]:
    """Return a new dict with *prefix* prepended to every key of *metrics*.

    Values are carried over unchanged and the input mapping is not modified.
    """
    renamed: dict[str, Any] = {}
    for name, score in metrics.items():
        renamed[f"{prefix}{name}"] = score
    return renamed


def resolve_hit(
hit: dict[str, Any],
title_index: dict[str, list[str]],
Expand Down
43 changes: 41 additions & 2 deletions dingo/retrieval/mteb_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from mteb.models.model_meta import ModelMeta
from tqdm import tqdm

from dingo.retrieval.eval_utils import normalize_title, resolve_hit
from dingo.retrieval.eval_utils import build_raw_api_ranked_doc_ids, compute_query_metrics, normalize_title, prefix_metrics, resolve_hit
from dingo.retrieval.search_client import SearchClient, SearchResponse

if TYPE_CHECKING:
Expand Down Expand Up @@ -335,6 +335,7 @@ def _process_query(idx_qid_text):
None,
None,
None,
None,
)

if response.error:
Expand All @@ -349,6 +350,7 @@ def _process_query(idx_qid_text):
None,
None,
None,
None,
)

doc_scores: dict[str, float] = {}
Expand Down Expand Up @@ -409,6 +411,17 @@ def _process_query(idx_qid_text):
if relevant_doc_ids is not None
else None
)
raw_api_metrics = (
prefix_metrics(
compute_query_metrics(
build_raw_api_ranked_doc_ids(top_api_results),
relevant_doc_ids,
),
"raw_api_",
)
if relevant_doc_ids is not None
else None
)

return (
idx,
Expand All @@ -421,6 +434,7 @@ def _process_query(idx_qid_text):
top_api_results,
mapping_stats,
relevant_matched_count,
raw_api_metrics,
)

items = [
Expand Down Expand Up @@ -451,9 +465,23 @@ def _process_query(idx_qid_text):
top_api_results,
mapping_stats,
relevant_matched_count,
raw_api_metrics,
) = future.result()

if doc_scores is None:
relevant_doc_ids = (
relevant_docs_by_qid.get(str(qid))
if relevant_docs_by_qid is not None
else None
)
raw_api_metrics = (
prefix_metrics(
compute_query_metrics([], relevant_doc_ids),
"raw_api_",
)
if relevant_doc_ids is not None
else None
)
errors += 1
logger.warning(
f"[{idx + 1}/{total}] {qid} ERROR: {response.error}"
Expand All @@ -475,7 +503,17 @@ def _process_query(idx_qid_text):
"matched_count": 0,
"mapped_count": 0,
"relevant_matched_count": 0,
"relevant_total": 0,
"relevant_total": (
len(relevant_doc_ids)
if relevant_doc_ids is not None
else 0
),
"gold_doc_ids": (
sorted(relevant_doc_ids)
if relevant_doc_ids is not None
else None
),
"raw_api_metrics": raw_api_metrics,
}
)
else:
Expand Down Expand Up @@ -520,6 +558,7 @@ def _process_query(idx_qid_text):
"top_api_results": top_api_results,
"retrieved_doc_ids": list(doc_scores.keys()),
"mapping_stats": mapping_stats,
"raw_api_metrics": raw_api_metrics,
}
)

Expand Down
19 changes: 18 additions & 1 deletion test/scripts/retrieval/test_eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from dingo.retrieval.eval_utils import compute_query_metrics, normalize_title, recall_at_k, resolve_hit, strip_d_prefix
from dingo.retrieval.eval_utils import build_raw_api_ranked_doc_ids, compute_query_metrics, normalize_title, recall_at_k, resolve_hit, strip_d_prefix


class TestNormalizeTitle:
Expand Down Expand Up @@ -95,6 +95,23 @@ def test_empty_retrieved(self):
assert metrics["recall_at_10"] == 0.0


class TestBuildRawApiRankedDocIds:
    """Checks for ``build_raw_api_ranked_doc_ids``."""

    def test_preserves_unmatched_and_duplicate_positions(self):
        """Unmatched and repeated hits keep their slot as placeholders."""
        api_hits = [
            {"rank": 1, "resolved_corpus_id": ""},
            {"rank": 2, "resolved_corpus_id": "d1"},
            {"rank": 3, "resolved_corpus_id": "d1"},
            {"rank": 4, "resolved_corpus_id": "d2"},
        ]

        ranked_ids = build_raw_api_ranked_doc_ids(api_hits)

        # Slot 0 had no resolved ID; slot 2 repeats the ID seen in slot 1.
        assert ranked_ids[0].startswith("__raw_api_unmatched__")
        assert ranked_ids[1] == "d1"
        assert ranked_ids[2].startswith("__raw_api_duplicate__")
        assert ranked_ids[3] == "d2"


class TestResolveHit:
def setup_method(self):
self.corpus_ids = {"d123", "d456", "d789"}
Expand Down
16 changes: 16 additions & 0 deletions test/scripts/retrieval/test_retrieval_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,23 @@ def fake_evaluate(model, tasks, overwrite_strategy):
"qid": "q1",
"retrieved_doc_ids": ["d1", "d2"],
"gold_doc_ids": ["d1"],
"api_results_count": 3,
"raw_api_metrics": {
"raw_api_ndcg_at_10": 0.5,
"raw_api_recall_at_10": 1.0,
"raw_api_mrr_at_10": 0.5,
},
},
{
"qid": "q2",
"retrieved_doc_ids": ["d3"],
"gold_doc_ids": ["d4"],
"api_results_count": 1,
"raw_api_metrics": {
"raw_api_ndcg_at_10": 0.0,
"raw_api_recall_at_10": 0.0,
"raw_api_mrr_at_10": 0.0,
},
},
],
})
Expand All @@ -183,3 +195,7 @@ def fake_evaluate(model, tasks, overwrite_strategy):
assert summary.metrics_score_stats["SciFact"]["main_score"] == 0.5
assert summary.metrics_score_stats["SciFact"]["ndcg_at_10"] == 0.5
assert summary.metrics_score_stats["SciFact"]["recall_at_10"] == 0.5
assert summary.metrics_score_stats["SciFact"]["raw_api_ndcg_at_10"] == 0.25
assert summary.metrics_score_stats["SciFact"]["raw_api_recall_at_10"] == 0.5
assert summary.metrics_score_stats["SciFact"]["raw_api_mrr_at_10"] == 0.25
assert summary.metrics_score_stats["SciFact"]["raw_api_avg_results_count"] == 2.0
Loading