From 8f5f06e8c9d938998929b7a2c9c81e5b5ff5a9bb Mon Sep 17 00:00:00 2001 From: chupei Date: Tue, 16 Jun 2026 15:09:06 +0800 Subject: [PATCH] feat: retrival add raw api metrics calculate --- dingo/exec/retrieval.py | 47 +++++++++++++++++++ dingo/retrieval/eval_utils.py | 28 +++++++++++ dingo/retrieval/mteb_adapter.py | 43 ++++++++++++++++- test/scripts/retrieval/test_eval_utils.py | 19 +++++++- .../retrieval/test_retrieval_executor.py | 16 +++++++ 5 files changed, 150 insertions(+), 3 deletions(-) diff --git a/dingo/exec/retrieval.py b/dingo/exec/retrieval.py index ce925333..1837c4da 100644 --- a/dingo/exec/retrieval.py +++ b/dingo/exec/retrieval.py @@ -37,6 +37,10 @@ "map_at_10", ] +RAW_API_METRICS_OF_INTEREST = [ + f"raw_api_{key}" for key in METRICS_OF_INTEREST if key != "main_score" +] + @Executor.register("retrieval") class RetrievalExecutor: @@ -133,6 +137,12 @@ def execute(self) -> SummaryModel: model.get_search_traces(), task_name, ) + task_metrics.update( + self._compute_raw_api_metrics_from_search_traces( + model.get_search_traces(), + task_name, + ) + ) all_results[task_name] = task_metrics except Exception as e: logger.error(f"Task {task_name!r} failed: {e}", exc_info=True) @@ -141,6 +151,12 @@ def execute(self) -> SummaryModel: task_name, ) if task_metrics: + task_metrics.update( + self._compute_raw_api_metrics_from_search_traces( + model.get_search_traces(), + task_name, + ) + ) logger.warning( "Using search trace fallback metrics for failed task %r", task_name, @@ -320,6 +336,37 @@ def _compute_metrics_from_search_traces( if values } + @staticmethod + def _compute_raw_api_metrics_from_search_traces( + traces: list[dict[str, Any]], + task_name: str, + ) -> dict[str, float]: + metric_values: dict[str, list[float]] = {} + api_results_counts: list[float] = [] + for trace in traces: + if trace.get("task") != task_name: + continue + for query in trace.get("queries", []): + raw_metrics = query.get("raw_api_metrics") or {} + for key in RAW_API_METRICS_OF_INTEREST: + if key not in raw_metrics: + continue + metric_values.setdefault(key, []).append(raw_metrics[key]) + if "api_results_count" in query: + api_results_counts.append(float(query.get("api_results_count") or 0)) + + summary = { + key: round(sum(values) / len(values), 5) + for key, values in metric_values.items() + if values + } + if api_results_counts: + summary["raw_api_avg_results_count"] = round( + sum(api_results_counts) / len(api_results_counts), + 5, + ) + return summary + def load_data(self): pass diff --git a/dingo/retrieval/eval_utils.py b/dingo/retrieval/eval_utils.py index bff6a04f..c325c6e6 100644 --- a/dingo/retrieval/eval_utils.py +++ b/dingo/retrieval/eval_utils.py @@ -131,6 +131,34 @@ def compute_query_metrics( } +def build_raw_api_ranked_doc_ids( + top_api_results: list[dict[str, Any]], +) -> list[str]: + """Build a ranked doc-id list that preserves raw API ranks for metrics. + + Unmatched hits and duplicate resolved corpus IDs are kept as unique + non-relevant placeholders so later API hits do not move upward. + """ + ranked: list[str] = [] + seen_resolved_ids: set[str] = set() + for index, result in enumerate(top_api_results, start=1): + rank = result.get("rank") or index + resolved_id = str(result.get("resolved_corpus_id") or "").strip() + if not resolved_id: + ranked.append(f"__raw_api_unmatched__{rank}") + continue + if resolved_id in seen_resolved_ids: + ranked.append(f"__raw_api_duplicate__{rank}") + continue + seen_resolved_ids.add(resolved_id) + ranked.append(resolved_id) + return ranked + + +def prefix_metrics(metrics: dict[str, Any], prefix: str) -> dict[str, Any]: + return {f"{prefix}{key}": value for key, value in metrics.items()} + + def resolve_hit( hit: dict[str, Any], title_index: dict[str, list[str]], diff --git a/dingo/retrieval/mteb_adapter.py b/dingo/retrieval/mteb_adapter.py index a6421bd4..1a75dd5a 100644 --- a/dingo/retrieval/mteb_adapter.py +++ b/dingo/retrieval/mteb_adapter.py @@ -26,7 +26,7 @@ from mteb.models.model_meta import ModelMeta from tqdm import tqdm -from dingo.retrieval.eval_utils import normalize_title, resolve_hit +from dingo.retrieval.eval_utils import build_raw_api_ranked_doc_ids, compute_query_metrics, normalize_title, prefix_metrics, resolve_hit from dingo.retrieval.search_client import SearchClient, SearchResponse if TYPE_CHECKING: @@ -335,6 +335,7 @@ def _process_query(idx_qid_text): None, None, None, + None, ) if response.error: @@ -349,6 +350,7 @@ def _process_query(idx_qid_text): None, None, None, + None, ) doc_scores: dict[str, float] = {} @@ -409,6 +411,17 @@ def _process_query(idx_qid_text): if relevant_doc_ids is not None else None ) + raw_api_metrics = ( + prefix_metrics( + compute_query_metrics( + build_raw_api_ranked_doc_ids(top_api_results), + relevant_doc_ids, + ), + "raw_api_", + ) + if relevant_doc_ids is not None + else None + ) return ( idx, @@ -421,6 +434,7 @@ def _process_query(idx_qid_text): top_api_results, mapping_stats, relevant_matched_count, + raw_api_metrics, ) items = [ @@ -451,9 +465,23 @@ def _process_query(idx_qid_text): top_api_results, mapping_stats, relevant_matched_count, + raw_api_metrics, ) = future.result() if doc_scores is None: + relevant_doc_ids = ( + relevant_docs_by_qid.get(str(qid)) + if relevant_docs_by_qid is not None + else None + ) + raw_api_metrics = ( + prefix_metrics( + compute_query_metrics([], relevant_doc_ids), + "raw_api_", + ) + if relevant_doc_ids is not None + else None + ) errors += 1 logger.warning( f"[{idx + 1}/{total}] {qid} ERROR: {response.error}" @@ -475,7 +503,17 @@ def _process_query(idx_qid_text): "matched_count": 0, "mapped_count": 0, "relevant_matched_count": 0, - "relevant_total": 0, + "relevant_total": ( + len(relevant_doc_ids) + if relevant_doc_ids is not None + else 0 + ), + "gold_doc_ids": ( + sorted(relevant_doc_ids) + if relevant_doc_ids is not None + else None + ), + "raw_api_metrics": raw_api_metrics, } ) else: @@ -520,6 +558,7 @@ def _process_query(idx_qid_text): "top_api_results": top_api_results, "retrieved_doc_ids": list(doc_scores.keys()), "mapping_stats": mapping_stats, + "raw_api_metrics": raw_api_metrics, } ) diff --git a/test/scripts/retrieval/test_eval_utils.py b/test/scripts/retrieval/test_eval_utils.py index ffb3800b..2e01dea6 100644 --- a/test/scripts/retrieval/test_eval_utils.py +++ b/test/scripts/retrieval/test_eval_utils.py @@ -2,7 +2,7 @@ import pytest -from dingo.retrieval.eval_utils import compute_query_metrics, normalize_title, recall_at_k, resolve_hit, strip_d_prefix +from dingo.retrieval.eval_utils import build_raw_api_ranked_doc_ids, compute_query_metrics, normalize_title, recall_at_k, resolve_hit, strip_d_prefix class TestNormalizeTitle: @@ -95,6 +95,23 @@ def test_empty_retrieved(self): assert metrics["recall_at_10"] == 0.0 +class TestBuildRawApiRankedDocIds: + def test_preserves_unmatched_and_duplicate_positions(self): + ranked = build_raw_api_ranked_doc_ids( + [ + {"rank": 1, "resolved_corpus_id": ""}, + {"rank": 2, "resolved_corpus_id": "d1"}, + {"rank": 3, "resolved_corpus_id": "d1"}, + {"rank": 4, "resolved_corpus_id": "d2"}, + ] + ) + + assert ranked[0].startswith("__raw_api_unmatched__") + assert ranked[1] == "d1" + assert ranked[2].startswith("__raw_api_duplicate__") + assert ranked[3] == "d2" + + class TestResolveHit: def setup_method(self): self.corpus_ids = {"d123", "d456", "d789"} diff --git a/test/scripts/retrieval/test_retrieval_executor.py b/test/scripts/retrieval/test_retrieval_executor.py index 9341d278..c3513db6 100644 --- a/test/scripts/retrieval/test_retrieval_executor.py +++ b/test/scripts/retrieval/test_retrieval_executor.py @@ -163,11 +163,23 @@ def fake_evaluate(model, tasks, overwrite_strategy): "qid": "q1", "retrieved_doc_ids": ["d1", "d2"], "gold_doc_ids": ["d1"], + "api_results_count": 3, + "raw_api_metrics": { + "raw_api_ndcg_at_10": 0.5, + "raw_api_recall_at_10": 1.0, + "raw_api_mrr_at_10": 0.5, + }, }, { "qid": "q2", "retrieved_doc_ids": ["d3"], "gold_doc_ids": ["d4"], + "api_results_count": 1, + "raw_api_metrics": { + "raw_api_ndcg_at_10": 0.0, + "raw_api_recall_at_10": 0.0, + "raw_api_mrr_at_10": 0.0, + }, }, ], }) @@ -183,3 +195,7 @@ def fake_evaluate(model, tasks, overwrite_strategy): assert summary.metrics_score_stats["SciFact"]["main_score"] == 0.5 assert summary.metrics_score_stats["SciFact"]["ndcg_at_10"] == 0.5 assert summary.metrics_score_stats["SciFact"]["recall_at_10"] == 0.5 + assert summary.metrics_score_stats["SciFact"]["raw_api_ndcg_at_10"] == 0.25 + assert summary.metrics_score_stats["SciFact"]["raw_api_recall_at_10"] == 0.5 + assert summary.metrics_score_stats["SciFact"]["raw_api_mrr_at_10"] == 0.25 + assert summary.metrics_score_stats["SciFact"]["raw_api_avg_results_count"] == 2.0