Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions dingo/config/input_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ class RetrievalArgs(BaseModel):
freshness_boost: Optional[str] = None
filters: Optional[List[Dict[str, Any]] | Dict[str, Any]] = None
max_queries: Optional[int] = None
title_fuzzy_enabled: bool = False
title_fuzzy_threshold: float = 0.95
title_fuzzy_margin: float = 0.01
title_fuzzy_min_len: int = 20
title_fuzzy_max_candidates: int = 300
timeout: float = 120.0
rate_limit: Optional[float] = None
max_retries: int = 3
Expand Down
67 changes: 65 additions & 2 deletions dingo/exec/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dingo.config.input_args import InputArgs
from dingo.exec.base import Executor
from dingo.io import SummaryModel
from dingo.retrieval.eval_utils import make_output_dir, save_json
from dingo.retrieval.eval_utils import compute_query_metrics, make_output_dir, save_json
from dingo.retrieval.mteb_adapter import SearchClientModel
from dingo.retrieval.search_client import create_client

Expand Down Expand Up @@ -84,6 +84,11 @@ def execute(self) -> SummaryModel:
search_limit=ra.limit,
max_queries=ra.max_queries,
max_workers=ra.max_workers,
title_fuzzy_enabled=ra.title_fuzzy_enabled,
title_fuzzy_threshold=ra.title_fuzzy_threshold,
title_fuzzy_margin=ra.title_fuzzy_margin,
title_fuzzy_min_len=ra.title_fuzzy_min_len,
title_fuzzy_max_candidates=ra.title_fuzzy_max_candidates,
)

output_dir = make_output_dir(
Expand Down Expand Up @@ -118,9 +123,29 @@ def execute(self) -> SummaryModel:
overwrite_strategy="always",
)
task_metrics = self._extract_metrics(results)
if not task_metrics:
logger.warning(
"MTEB returned empty metrics for task %r; "
"falling back to search trace metrics",
task_name,
)
task_metrics = self._compute_metrics_from_search_traces(
model.get_search_traces(),
task_name,
)
all_results[task_name] = task_metrics
except Exception as e:
logger.error(f"Task {task_name!r} failed: {e}")
logger.error(f"Task {task_name!r} failed: {e}", exc_info=True)
task_metrics = self._compute_metrics_from_search_traces(
model.get_search_traces(),
task_name,
)
if task_metrics:
logger.warning(
"Using search trace fallback metrics for failed task %r",
task_name,
)
all_results[task_name] = task_metrics
continue

self._all_results = all_results
Expand All @@ -139,6 +164,11 @@ def execute(self) -> SummaryModel:
"limit": ra.limit,
"retrieval_mode": ra.retrieval_mode,
"sub_queries": ra.sub_queries,
"title_fuzzy_enabled": ra.title_fuzzy_enabled,
"title_fuzzy_threshold": ra.title_fuzzy_threshold,
"title_fuzzy_margin": ra.title_fuzzy_margin,
"title_fuzzy_min_len": ra.title_fuzzy_min_len,
"title_fuzzy_max_candidates": ra.title_fuzzy_max_candidates,
"max_queries": ra.max_queries,
"tasks": task_names,
}
Expand Down Expand Up @@ -257,6 +287,39 @@ def _extract_metrics(self, model_result) -> dict[str, float]:
metrics[key] = round(score_entry[key], 5)
return metrics

@staticmethod
def _compute_metrics_from_search_traces(
traces: list[dict[str, Any]],
task_name: str,
) -> dict[str, float]:
"""Compute fallback retrieval metrics from stored trace qrels/results."""
metric_values: dict[str, list[float]] = {}
for trace in traces:
if trace.get("task") != task_name:
continue
for query in trace.get("queries", []):
retrieved_doc_ids = query.get("retrieved_doc_ids") or []
gold_doc_ids = set(query.get("gold_doc_ids") or [])
if not gold_doc_ids:
continue

query_metrics = compute_query_metrics(
retrieved_doc_ids,
gold_doc_ids,
)
query_metrics["main_score"] = query_metrics.get("ndcg_at_10", 0.0)

for key in METRICS_OF_INTEREST:
if key not in query_metrics:
continue
metric_values.setdefault(key, []).append(query_metrics[key])

return {
key: round(sum(values) / len(values), 5)
for key, values in metric_values.items()
if values
}

def load_data(self):
pass

Expand Down
79 changes: 70 additions & 9 deletions dingo/retrieval/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re
import unicodedata
from datetime import datetime
from difflib import SequenceMatcher
from typing import Any

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -54,16 +55,18 @@ def compute_query_metrics(
retrieved_doc_ids: list[str],
relevant_doc_ids: set[str],
) -> dict[str, Any]:
"""Compute nDCG@10, MRR@10, Recall@{5,10,100,1000} for a single query."""
"""Compute standard retrieval metrics for a single query."""
top5 = retrieved_doc_ids[:5]
top10 = retrieved_doc_ids[:10]
top20 = retrieved_doc_ids[:20]
top100 = retrieved_doc_ids[:100]
top1000 = retrieved_doc_ids[:1000]

rel_flags_5 = [1 if did in relevant_doc_ids else 0 for did in top5]
rel_in_5 = sum(rel_flags_5)
rel_flags_10 = [1 if did in relevant_doc_ids else 0 for did in top10]
rel_in_10 = sum(rel_flags_10)
rel_flags_100 = [1 if did in relevant_doc_ids else 0 for did in top100]
rel_total = len(relevant_doc_ids)

first_rel_rank = -1
Expand All @@ -73,12 +76,21 @@ def compute_query_metrics(
break
mrr10 = 1.0 / first_rel_rank if first_rel_rank > 0 else 0.0

ideal_len = min(rel_total, 10)
idcg10 = dcg([1] * ideal_len, 10) if ideal_len > 0 else 0.0
ideal_len_10 = min(rel_total, 10)
idcg10 = dcg([1] * ideal_len_10, 10) if ideal_len_10 > 0 else 0.0
ndcg10 = (dcg(rel_flags_10, 10) / idcg10) if idcg10 > 0 else 0.0

ideal_len_100 = min(rel_total, 100)
idcg100 = dcg([1] * ideal_len_100, 100) if ideal_len_100 > 0 else 0.0
ndcg100 = (dcg(rel_flags_100, 100) / idcg100) if idcg100 > 0 else 0.0

recall5 = (rel_in_5 / rel_total) if rel_total > 0 else 0.0
recall10 = (rel_in_10 / rel_total) if rel_total > 0 else 0.0
recall20 = (
sum(1 for did in top20 if did in relevant_doc_ids) / rel_total
if rel_total > 0
else 0.0
)
recall100 = (
sum(1 for did in top100 if did in relevant_doc_ids) / rel_total
if rel_total > 0
Expand All @@ -90,42 +102,91 @@ def compute_query_metrics(
else 0.0
)

hits = 0
precision_sum = 0.0
for rank, did in enumerate(top10, start=1):
if did in relevant_doc_ids:
hits += 1
precision_sum += hits / rank
map10 = (
precision_sum / min(rel_total, 10)
if rel_total > 0
else 0.0
)

return {
"first_relevant_rank_at_10": first_rel_rank,
"relevant_in_top10": rel_in_10,
"relevant_total": rel_total,
"ndcg_at_10": round(ndcg10, 5),
"ndcg_at_100": round(ndcg100, 5),
"mrr_at_10": round(mrr10, 5),
"recall_at_5": round(recall5, 5),
"recall_at_10": round(recall10, 5),
"recall_at_20": round(recall20, 5),
"recall_at_100": round(recall100, 5),
"recall_at_1000": round(recall1000, 5),
"precision_at_10": round(rel_in_10 / 10, 5),
"map_at_10": round(map10, 5),
}


def resolve_hit(
hit: dict[str, Any],
title_index: dict[str, list[str]],
corpus_id_set: set[str],
) -> tuple[str, str]:
*,
title_fuzzy_enabled: bool = False,
title_fuzzy_threshold: float = 0.95,
title_fuzzy_margin: float = 0.01,
title_fuzzy_min_len: int = 20,
title_norm_candidates: list[tuple[str, list[str]]] | None = None,
) -> tuple[str, str, float | None]:
"""Resolve a search hit to a corpus ID.

Returns ``(corpus_id, mapping_source)`` where *mapping_source* is one of
``"doc_id_exact"``, ``"title_fallback"``, or ``"unmatched"``.
Returns ``(corpus_id, mapping_source, fuzzy_similarity)`` where
*mapping_source* is one of ``"doc_id_exact"``, ``"title_fallback"``,
``"title_fuzzy"``, or ``"unmatched"``.
"""
raw_id = str(hit.get("doc_id") or hit.get("paper_id") or "").strip()
if raw_id:
stripped = strip_d_prefix(raw_id)
for candidate in (raw_id, stripped, f"d{stripped}"):
if candidate in corpus_id_set:
return candidate, "doc_id_exact"
return candidate, "doc_id_exact", None
title = str(hit.get("title") or "")
norm = normalize_title(title)
if norm:
candidates = title_index.get(norm)
if candidates:
return candidates[0], "title_fallback"
return "", "unmatched"
return candidates[0], "title_fallback", None

if title_fuzzy_enabled and norm and len(norm) >= title_fuzzy_min_len:
iterable_candidates = (
title_norm_candidates
if title_norm_candidates is not None
else list(title_index.items())
)
best_ids: list[str] | None = None
best_score = -1.0
second_score = -1.0
for candidate_norm, candidate_ids in iterable_candidates:
score = SequenceMatcher(None, norm, candidate_norm).ratio()
if score > best_score:
second_score = best_score
best_score = score
best_ids = candidate_ids
elif score > second_score:
second_score = score
Comment on lines +173 to +180

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using SequenceMatcher.ratio() in a loop over hundreds of candidates for every search hit can be extremely slow and lead to significant performance bottlenecks. We can optimize this by leveraging real_quick_ratio() and quick_ratio(), which are fast-to-compute upper bounds of the true ratio. If these upper bounds are less than or equal to the current second_score, we can safely skip the expensive ratio() computation.

        for candidate_norm, candidate_ids in iterable_candidates:
            matcher = SequenceMatcher(None, norm, candidate_norm)
            if matcher.real_quick_ratio() <= second_score:
                continue
            if matcher.quick_ratio() <= second_score:
                continue
            score = matcher.ratio()
            if score > best_score:
                second_score = best_score
                best_score = score
                best_ids = candidate_ids
            elif score > second_score:
                second_score = score


if (
best_ids
and best_score >= title_fuzzy_threshold
and (best_score - second_score) >= title_fuzzy_margin
):
return best_ids[0], "title_fuzzy", round(best_score, 6)

return "", "unmatched", None


def make_output_dir(explicit_dir: str | None, default_prefix: str) -> str:
Expand Down
Loading
Loading